From 4b6ec248a6afb35c41c45578529b03293870ace2 Mon Sep 17 00:00:00 2001 From: sabaimran <65192171+sabaimran@users.noreply.github.com> Date: Thu, 26 Oct 2023 11:37:41 -0700 Subject: [PATCH] [Multi-User Part 3]: Separate chat sessions based on authenticated users (#511) - Add a data model which allows us to store Conversations with users. This does a minimal lift over the current setup, where the underlying data is stored in a JSON file. This maintains parity with that configuration. - There does _seem_ to be some regression in chat quality, which is most likely attributable to search results. This will help us with #275. It should become much easier to maintain multiple Conversations in a given table in the backend now. We will have to do some thinking on the UI. --- pytest.ini | 2 + src/database/adapters/__init__.py | 145 +++++++++++++-- ...onprocessorconfig_conversation_and_more.py | 81 ++++++++ ...008_alter_conversation_conversation_log.py | 17 ++ src/database/models/__init__.py | 22 ++- src/khoj/configure.py | 110 +---------- src/khoj/interface/web/config.html | 13 +- src/khoj/routers/api.py | 175 ++++++++---------- src/khoj/routers/helpers.py | 87 ++++++--- src/khoj/routers/indexer.py | 27 +-- src/khoj/routers/web_client.py | 54 +++--- src/khoj/search_type/image_search.py | 1 + src/khoj/search_type/text_search.py | 56 +++--- src/khoj/utils/config.py | 39 +--- src/khoj/utils/constants.py | 122 ------------ src/khoj/utils/helpers.py | 23 +++ src/khoj/utils/rawconfig.py | 19 +- src/khoj/utils/state.py | 4 +- tests/conftest.py | 134 ++++++-------- tests/helpers.py | 51 +++++ tests/test_client.py | 5 + tests/test_gpt4all_chat_director.py | 76 +++++--- tests/test_openai_chat_director.py | 77 +++++--- tests/test_text_search.py | 5 +- 24 files changed, 719 insertions(+), 626 deletions(-) create mode 100644 src/database/migrations/0007_remove_conversationprocessorconfig_conversation_and_more.py create mode 100644 src/database/migrations/0008_alter_conversation_conversation_log.py create mode 100644 tests/helpers.py diff --git a/pytest.ini b/pytest.ini index eec111ec..b3e418d0 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,3 +2,5 @@ DJANGO_SETTINGS_MODULE = app.settings pythonpath = . 
src testpaths = tests +markers = + chatquality: marks tests as chatquality (deselect with '-m "not chatquality"') diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index fc4f23b1..db5e9f77 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -1,5 +1,4 @@ from typing import Type, TypeVar, List -import uuid from datetime import date from django.db import models @@ -21,6 +20,13 @@ from database.models import ( GithubConfig, Embeddings, GithubRepoConfig, + Conversation, + ConversationProcessorConfig, + OpenAIProcessorConversationConfig, + OfflineChatProcessorConversationConfig, +) +from khoj.utils.rawconfig import ( + ConversationProcessorConfig as UserConversationProcessorConfig, ) from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.file_filter import FileFilter @@ -54,18 +60,17 @@ async def get_or_create_user(token: dict) -> KhojUser: async def create_google_user(token: dict) -> KhojUser: - user_info = token.get("userinfo") - user = await KhojUser.objects.acreate(username=user_info.get("email"), email=user_info.get("email")) + user = await KhojUser.objects.acreate(username=token.get("email"), email=token.get("email")) await user.asave() await GoogleUser.objects.acreate( - sub=user_info.get("sub"), - azp=user_info.get("azp"), - email=user_info.get("email"), - name=user_info.get("name"), - given_name=user_info.get("given_name"), - family_name=user_info.get("family_name"), - picture=user_info.get("picture"), - locale=user_info.get("locale"), + sub=token.get("sub"), + azp=token.get("azp"), + email=token.get("email"), + name=token.get("name"), + given_name=token.get("given_name"), + family_name=token.get("family_name"), + picture=token.get("picture"), + locale=token.get("locale"), user=user, ) @@ -137,6 +142,124 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list): return config +class ConversationAdapters: + @staticmethod + def get_conversation_by_user(user: KhojUser): + conversation = Conversation.objects.filter(user=user) + if conversation.exists(): + return conversation.first() + return Conversation.objects.create(user=user) + + @staticmethod + async def aget_conversation_by_user(user: KhojUser): + conversation = Conversation.objects.filter(user=user) + if await conversation.aexists(): + return await conversation.afirst() + return await Conversation.objects.acreate(user=user) + + @staticmethod + def has_any_conversation_config(user: KhojUser): + return ConversationProcessorConfig.objects.filter(user=user).exists() + + @staticmethod + def get_openai_conversation_config(user: KhojUser): + return OpenAIProcessorConversationConfig.objects.filter(user=user).first() + + @staticmethod + def get_offline_chat_conversation_config(user: KhojUser): + return OfflineChatProcessorConversationConfig.objects.filter(user=user).first() + + @staticmethod + def has_valid_offline_conversation_config(user: KhojUser): + return OfflineChatProcessorConversationConfig.objects.filter(user=user, enable_offline_chat=True).exists() + + @staticmethod + def has_valid_openai_conversation_config(user: KhojUser): + return OpenAIProcessorConversationConfig.objects.filter(user=user).exists() + + @staticmethod + def get_conversation_config(user: KhojUser): + return ConversationProcessorConfig.objects.filter(user=user).first() + + @staticmethod + def save_conversation(user: KhojUser, conversation_log: dict): + conversation = Conversation.objects.filter(user=user) + if conversation.exists(): + 
conversation.update(conversation_log=conversation_log) + else: + Conversation.objects.create(user=user, conversation_log=conversation_log) + + @staticmethod + def set_conversation_processor_config(user: KhojUser, new_config: UserConversationProcessorConfig): + conversation_config, _ = ConversationProcessorConfig.objects.get_or_create(user=user) + conversation_config.max_prompt_size = new_config.max_prompt_size + conversation_config.tokenizer = new_config.tokenizer + conversation_config.save() + + if new_config.openai: + default_values = { + "api_key": new_config.openai.api_key, + } + if new_config.openai.chat_model: + default_values["chat_model"] = new_config.openai.chat_model + + OpenAIProcessorConversationConfig.objects.update_or_create(user=user, defaults=default_values) + + if new_config.offline_chat: + default_values = { + "enable_offline_chat": str(new_config.offline_chat.enable_offline_chat), + } + + if new_config.offline_chat.chat_model: + default_values["chat_model"] = new_config.offline_chat.chat_model + + OfflineChatProcessorConversationConfig.objects.update_or_create(user=user, defaults=default_values) + + @staticmethod + def get_enabled_conversation_settings(user: KhojUser): + openai_config = ConversationAdapters.get_openai_conversation_config(user) + offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user) + + return { + "openai": True if openai_config is not None else False, + "offline_chat": True + if (offline_chat_config is not None and offline_chat_config.enable_offline_chat) + else False, + } + + @staticmethod + def clear_conversation_config(user: KhojUser): + ConversationProcessorConfig.objects.filter(user=user).delete() + ConversationAdapters.clear_openai_conversation_config(user) + ConversationAdapters.clear_offline_chat_conversation_config(user) + + @staticmethod + def clear_openai_conversation_config(user: KhojUser): + OpenAIProcessorConversationConfig.objects.filter(user=user).delete() + + @staticmethod + def clear_offline_chat_conversation_config(user: KhojUser): + OfflineChatProcessorConversationConfig.objects.filter(user=user).delete() + + @staticmethod + async def has_offline_chat(user: KhojUser): + return await OfflineChatProcessorConversationConfig.objects.filter( + user=user, enable_offline_chat=True + ).aexists() + + @staticmethod + async def get_offline_chat(user: KhojUser): + return await OfflineChatProcessorConversationConfig.objects.filter(user=user).afirst() + + @staticmethod + async def has_openai_chat(user: KhojUser): + return await OpenAIProcessorConversationConfig.objects.filter(user=user).aexists() + + @staticmethod + async def get_openai_chat(user: KhojUser): + return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst() + + class EmbeddingsAdapters: word_filer = WordFilter() file_filter = FileFilter() diff --git a/src/database/migrations/0007_remove_conversationprocessorconfig_conversation_and_more.py b/src/database/migrations/0007_remove_conversationprocessorconfig_conversation_and_more.py new file mode 100644 index 00000000..d66b2bd0 --- /dev/null +++ b/src/database/migrations/0007_remove_conversationprocessorconfig_conversation_and_more.py @@ -0,0 +1,81 @@ +# Generated by Django 4.2.5 on 2023-10-18 05:31 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0006_embeddingsdates"), + ] + + operations = [ + migrations.RemoveField( + 
model_name="conversationprocessorconfig", + name="conversation", + ), + migrations.RemoveField( + model_name="conversationprocessorconfig", + name="enable_offline_chat", + ), + migrations.AddField( + model_name="conversationprocessorconfig", + name="max_prompt_size", + field=models.IntegerField(blank=True, default=None, null=True), + ), + migrations.AddField( + model_name="conversationprocessorconfig", + name="tokenizer", + field=models.CharField(blank=True, default=None, max_length=200, null=True), + ), + migrations.AddField( + model_name="conversationprocessorconfig", + name="user", + field=models.ForeignKey( + default=1, on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL + ), + preserve_default=False, + ), + migrations.CreateModel( + name="OpenAIProcessorConversationConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("api_key", models.CharField(max_length=200)), + ("chat_model", models.CharField(max_length=200)), + ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="OfflineChatProcessorConversationConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("enable_offline_chat", models.BooleanField(default=False)), + ("chat_model", models.CharField(default="llama-2-7b-chat.ggmlv3.q4_0.bin", max_length=200)), + ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + options={ + "abstract": False, + }, + ), + migrations.CreateModel( + name="Conversation", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("conversation_log", models.JSONField()), + ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/src/database/migrations/0008_alter_conversation_conversation_log.py b/src/database/migrations/0008_alter_conversation_conversation_log.py new file mode 100644 index 00000000..8c60489f --- /dev/null +++ b/src/database/migrations/0008_alter_conversation_conversation_log.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.5 on 2023-10-18 16:46 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0007_remove_conversationprocessorconfig_conversation_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="conversation", + name="conversation_log", + field=models.JSONField(default=dict), + ), + ] diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py index 9a50d94f..a9d41e0d 100644 --- a/src/database/models/__init__.py +++ b/src/database/models/__init__.py @@ -82,9 +82,27 @@ class LocalPlaintextConfig(BaseModel): user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) -class ConversationProcessorConfig(BaseModel): - conversation = models.JSONField() +class OpenAIProcessorConversationConfig(BaseModel): + api_key = 
models.CharField(max_length=200) + chat_model = models.CharField(max_length=200) + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) + + +class OfflineChatProcessorConversationConfig(BaseModel): enable_offline_chat = models.BooleanField(default=False) + chat_model = models.CharField(max_length=200, default="llama-2-7b-chat.ggmlv3.q4_0.bin") + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) + + +class ConversationProcessorConfig(BaseModel): + max_prompt_size = models.IntegerField(default=None, null=True, blank=True) + tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) + + +class Conversation(BaseModel): + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) + conversation_log = models.JSONField(default=dict) class Embeddings(BaseModel): diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 76b2e9f4..ac43f9b4 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -23,12 +23,10 @@ from starlette.authentication import ( from khoj.utils import constants, state from khoj.utils.config import ( SearchType, - ProcessorConfigModel, - ConversationProcessorConfigModel, ) -from khoj.utils.helpers import resolve_absolute_path, merge_dicts +from khoj.utils.helpers import merge_dicts from khoj.utils.fs_syncer import collect_files -from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig +from khoj.utils.rawconfig import FullConfig from khoj.routers.indexer import configure_content, load_content, configure_search from database.models import KhojUser from database.adapters import get_all_users @@ -98,13 +96,6 @@ def configure_server( # Update Config state.config = config - # Initialize Processor from Config - try: - state.processor_config = configure_processor(state.config.processor) - except Exception as e: - logger.error(f"🚨 Failed to configure processor", exc_info=True) - raise e - # Initialize Search Models from Config and initialize content try: state.config_lock.acquire() @@ -190,103 +181,6 @@ def configure_search_types(config: FullConfig): return Enum("SearchType", merge_dicts(core_search_types, {})) -def configure_processor( - processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None -): - if not processor_config: - logger.warning("🚨 No Processor configuration available.") - return None - - processor = ProcessorConfigModel() - - # Initialize Conversation Processor - logger.info("💬 Setting up conversation processor") - processor.conversation = configure_conversation_processor(processor_config, state_processor_config) - - return processor - - -def configure_conversation_processor( - processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None -): - if ( - not processor_config - or not processor_config.conversation - or not processor_config.conversation.conversation_logfile - ): - default_config = constants.default_config - default_conversation_logfile = resolve_absolute_path( - default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore - ) - conversation_logfile = resolve_absolute_path(default_conversation_logfile) - conversation_config = processor_config.conversation if processor_config else None - conversation_processor = ConversationProcessorConfigModel( - conversation_config=ConversationProcessorConfig( - conversation_logfile=conversation_logfile, - 
openai=(conversation_config.openai if (conversation_config is not None) else None), - offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(), - max_prompt_size=conversation_config.max_prompt_size if conversation_config else None, - tokenizer=conversation_config.tokenizer if conversation_config else None, - ) - ) - else: - conversation_processor = ConversationProcessorConfigModel( - conversation_config=processor_config.conversation, - ) - conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile) - - # Load Conversation Logs from Disk - if state_processor_config and state_processor_config.conversation and state_processor_config.conversation.meta_log: - conversation_processor.meta_log = state_processor_config.conversation.meta_log - conversation_processor.chat_session = state_processor_config.conversation.chat_session - logger.debug(f"Loaded conversation logs from state") - return conversation_processor - - if conversation_logfile.is_file(): - # Load Metadata Logs from Conversation Logfile - with conversation_logfile.open("r") as f: - conversation_processor.meta_log = json.load(f) - logger.debug(f"Loaded conversation logs from {conversation_logfile}") - else: - # Initialize Conversation Logs - conversation_processor.meta_log = {} - conversation_processor.chat_session = [] - - return conversation_processor - - -@schedule.repeat(schedule.every(17).minutes) -def save_chat_session(): - # No need to create empty log file - if not ( - state.processor_config - and state.processor_config.conversation - and state.processor_config.conversation.meta_log - and state.processor_config.conversation.chat_session - ): - return - - # Summarize Conversation Logs for this Session - conversation_log = state.processor_config.conversation.meta_log - session = { - "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"], - "session-end": len(conversation_log["chat"]), - } - if "session" in conversation_log: - conversation_log["session"].append(session) - else: - conversation_log["session"] = [session] - - # Save Conversation Metadata Logs to Disk - conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile) - conversation_logfile.parent.mkdir(parents=True, exist_ok=True) # create conversation directory if doesn't exist - with open(conversation_logfile, "w+", encoding="utf-8") as logfile: - json.dump(conversation_log, logfile, indent=2) - - state.processor_config.conversation.chat_session = [] - logger.info("📩 Saved current chat session to conversation logs") - - @schedule.repeat(schedule.every(59).minutes) def upload_telemetry(): if not state.config or not state.config.app or not state.config.app.should_log_telemetry or not state.telemetry: diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 6c69c056..979bf56c 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -3,6 +3,11 @@
+ {% if anonymous_mode == False %} +
+ Logged in as {{ username }} +
+ {% endif %}

Plugins

@@ -257,8 +262,8 @@ Chat

Offline Chat - Configured - {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and not current_model_state.conversation_gpt4all %} + Configured + {% if current_model_state.enable_offline_model and not current_model_state.conversation_gpt4all %} Not Configured {% endif %}
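[Editor's note: the "Configured"/"Not Configured" badge above now keys off server-side state instead of the raw config file. A rough sketch of how the two flags it consults might be derived, mirroring the web_client.py changes later in this patch; names are taken from that diff, but this is an illustration, not the exact code.]

```python
from khoj.utils import state
from database.adapters import ConversationAdapters
from database.models import KhojUser


def offline_chat_flags(user: KhojUser) -> dict:
    # Mirrors how config_page() in web_client.py builds current_model_state
    enabled = ConversationAdapters.get_enabled_conversation_settings(user)
    return {
        # Offline chat is enabled for this user in the database
        "enable_offline_model": enabled["offline_chat"],
        # True only if the gpt4all model actually loaded into memory
        "conversation_gpt4all": state.gpt4all_processor_config.loaded_model is not None
        if state.gpt4all_processor_config
        else False,
    }
```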

@@ -266,12 +271,12 @@

Setup offline chat

-
+
-
+
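[Editor's note: before the API changes below, it may help to see the new per-user conversation flow end to end. This is a minimal usage sketch of the ConversationAdapters introduced earlier in this patch; the chat-log entry shape shown here is illustrative only, since the real schema is produced by message_to_log.]

```python
from database.adapters import ConversationAdapters
from database.models import KhojUser


def record_chat_turn(user: KhojUser, question: str, answer: str):
    # Each user currently maps to a single Conversation row; its JSON
    # conversation_log mirrors the old on-disk conversation_logs.json.
    conversation = ConversationAdapters.get_conversation_by_user(user)
    log = conversation.conversation_log or {}
    # Illustrative entry shape; the actual structure is built by message_to_log
    log.setdefault("chat", []).append({"by": "you", "message": question})
    log["chat"].append({"by": "khoj", "message": answer})
    ConversationAdapters.save_conversation(user, log)
```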
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index d041fd76..b7ba66b6 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -8,21 +8,20 @@ from typing import List, Optional, Union, Any import asyncio # External Packages -from fastapi import APIRouter, HTTPException, Header, Request, Depends +from fastapi import APIRouter, HTTPException, Header, Request from starlette.authentication import requires from asgiref.sync import sync_to_async # Internal Packages -from khoj.configure import configure_processor, configure_server +from khoj.configure import configure_server from khoj.search_type import image_search, text_search from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.word_filter import WordFilter -from khoj.utils.config import TextSearchModel +from khoj.utils.config import TextSearchModel, GPT4AllProcessorModel from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer, command_descriptions from khoj.utils.rawconfig import ( FullConfig, - ProcessorConfig, SearchConfig, SearchResponse, TextContentConfig, @@ -32,16 +31,16 @@ from khoj.utils.rawconfig import ( ConversationProcessorConfig, OfflineChatProcessorConfig, ) -from khoj.utils.helpers import resolve_absolute_path from khoj.utils.state import SearchType from khoj.utils import state, constants -from khoj.utils.yaml import save_config_to_file_updated_state +from khoj.utils.helpers import AsyncIteratorWrapper from fastapi.responses import StreamingResponse, Response from khoj.routers.helpers import ( get_conversation_command, perform_chat_checks, - generate_chat_response, + agenerate_chat_response, update_telemetry_state, + is_ready_to_chat, ) from khoj.processor.conversation.prompts import help_message from khoj.processor.conversation.openai.gpt import extract_questions @@ -49,7 +48,7 @@ from khoj.processor.conversation.gpt4all.chat_model import extract_questions_off from fastapi.requests import Request from database import adapters -from database.adapters import EmbeddingsAdapters +from database.adapters import EmbeddingsAdapters, ConversationAdapters from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser @@ -114,6 +113,8 @@ async def map_config_to_db(config: FullConfig, user: KhojUser): user=user, token=config.content_type.notion.token, ) + if config.processor and config.processor.conversation: + ConversationAdapters.set_conversation_processor_config(user, config.processor.conversation) # If it's a demo instance, prevent updating any of the configuration. 
@@ -123,8 +124,6 @@ if not state.demo: if state.config is None: state.config = FullConfig() state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"]) - if state.processor_config is None: - state.processor_config = configure_processor(state.config.processor) @api.get("/config/data", response_model=FullConfig) @requires(["authenticated"], redirect="login_page") @@ -238,28 +237,24 @@ if not state.demo: ) content_object = map_config_to_object(content_type) + if content_object is None: + raise ValueError(f"Invalid content type: {content_type}") + await content_object.objects.filter(user=user).adelete() await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type) enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user) - return {"status": "ok"} @api.post("/delete/config/data/processor/conversation/openai", status_code=200) + @requires(["authenticated"], redirect="login_page") async def remove_processor_conversation_config_data( request: Request, client: Optional[str] = None, ): - if ( - not state.config - or not state.config.processor - or not state.config.processor.conversation - or not state.config.processor.conversation.openai - ): - return {"status": "ok"} + user = request.user.object - state.config.processor.conversation.openai = None - state.processor_config = configure_processor(state.config.processor, state.processor_config) + await sync_to_async(ConversationAdapters.clear_openai_conversation_config)(user) update_telemetry_state( request=request, @@ -269,11 +264,7 @@ if not state.demo: metadata={"processor_conversation_type": "openai"}, ) - try: - save_config_to_file_updated_state() - return {"status": "ok"} - except Exception as e: - return {"status": "error", "message": str(e)} + return {"status": "ok"} @api.post("/config/data/content_type/{content_type}", status_code=200) @requires(["authenticated"], redirect="login_page") @@ -301,24 +292,17 @@ if not state.demo: return {"status": "ok"} @api.post("/config/data/processor/conversation/openai", status_code=200) + @requires(["authenticated"], redirect="login_page") async def set_processor_openai_config_data( request: Request, updated_config: Union[OpenAIProcessorConfig, None], client: Optional[str] = None, ): - _initialize_config() + user = request.user.object - if not state.config.processor or not state.config.processor.conversation: - default_config = constants.default_config - default_conversation_logfile = resolve_absolute_path( - default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore - ) - conversation_logfile = resolve_absolute_path(default_conversation_logfile) - state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore + conversation_config = ConversationProcessorConfig(openai=updated_config) - assert state.config.processor.conversation is not None - state.config.processor.conversation.openai = updated_config - state.processor_config = configure_processor(state.config.processor, state.processor_config) + await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config) update_telemetry_state( request=request, @@ -328,11 +312,7 @@ if not state.demo: metadata={"processor_conversation_type": "conversation"}, ) - try: - save_config_to_file_updated_state() - return {"status": "ok"} - except Exception as e: - return {"status": "error", "message": str(e)} + return {"status": "ok"} 
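[Editor's note: the rewritten endpoints above call synchronous Django ORM adapter methods from async FastAPI handlers via sync_to_async, since Django's ORM must not run directly on the event loop. A minimal sketch of the pattern, assuming a Django-configured environment.]

```python
from asgiref.sync import sync_to_async
from database.adapters import ConversationAdapters
from database.models import KhojUser


async def clear_openai_settings(user: KhojUser):
    # sync_to_async pushes the blocking ORM call onto a worker thread,
    # keeping the event loop free while the delete executes.
    await sync_to_async(ConversationAdapters.clear_openai_conversation_config)(user)
```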
@api.post("/config/data/processor/conversation/offline_chat", status_code=200) async def set_processor_enable_offline_chat_config_data( @@ -341,24 +321,26 @@ if not state.demo: offline_chat_model: Optional[str] = None, client: Optional[str] = None, ): - _initialize_config() + user = request.user.object - if not state.config.processor or not state.config.processor.conversation: - default_config = constants.default_config - default_conversation_logfile = resolve_absolute_path( - default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore + if enable_offline_chat: + conversation_config = ConversationProcessorConfig( + offline_chat=OfflineChatProcessorConfig( + enable_offline_chat=enable_offline_chat, + chat_model=offline_chat_model, + ) ) - conversation_logfile = resolve_absolute_path(default_conversation_logfile) - state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore - assert state.config.processor.conversation is not None - if state.config.processor.conversation.offline_chat is None: - state.config.processor.conversation.offline_chat = OfflineChatProcessorConfig() + await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config) - state.config.processor.conversation.offline_chat.enable_offline_chat = enable_offline_chat - if offline_chat_model is not None: - state.config.processor.conversation.offline_chat.chat_model = offline_chat_model - state.processor_config = configure_processor(state.config.processor, state.processor_config) + offline_chat = await ConversationAdapters.get_offline_chat(user) + chat_model = offline_chat.chat_model + if state.gpt4all_processor_config is None: + state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model) + + else: + await sync_to_async(ConversationAdapters.clear_offline_chat_conversation_config)(user) + state.gpt4all_processor_config = None update_telemetry_state( request=request, @@ -368,11 +350,7 @@ if not state.demo: metadata={"processor_conversation_type": f"{'enable' if enable_offline_chat else 'disable'}_local_llm"}, ) - try: - save_config_to_file_updated_state() - return {"status": "ok"} - except Exception as e: - return {"status": "error", "message": str(e)} + return {"status": "ok"} # Create Routes @@ -426,9 +404,6 @@ async def search( if q is None or q == "": logger.warning(f"No query param (q) passed in API call to initiate search") return results - if not state.search_models or not any(state.search_models.__dict__.values()): - logger.warning(f"No search models loaded. 
Configure a search model before initiating search") - return results # initialize variables user_query = q.strip() @@ -565,8 +540,6 @@ def update( components.append("Search models") if state.content_index: components.append("Content index") - if state.processor_config: - components.append("Conversation processor") components_msg = ", ".join(components) logger.info(f"📪 {components_msg} updated via API") @@ -592,12 +565,11 @@ def chat_history( referer: Optional[str] = Header(None), host: Optional[str] = Header(None), ): - perform_chat_checks() + user = request.user.object + perform_chat_checks(user) # Load Conversation History - meta_log = {} - if state.processor_config.conversation: - meta_log = state.processor_config.conversation.meta_log + meta_log = ConversationAdapters.get_conversation_by_user(user=user).conversation_log update_telemetry_state( request=request, @@ -649,30 +621,35 @@ async def chat( referer: Optional[str] = Header(None), host: Optional[str] = Header(None), ) -> Response: - perform_chat_checks() + user = request.user.object + + await is_ready_to_chat(user) conversation_command = get_conversation_command(query=q, any_references=True) q = q.replace(f"/{conversation_command.value}", "").strip() + meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log + compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( - request, q, (n or 5), conversation_command + request, meta_log, q, (n or 5), conversation_command ) if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references): conversation_command = ConversationCommand.General if conversation_command == ConversationCommand.Help: - model_type = "offline" if state.processor_config.conversation.offline_chat.enable_offline_chat else "openai" + model_type = "offline" if await ConversationAdapters.has_offline_chat(user) else "openai" formatted_help = help_message.format(model=model_type, version=state.khoj_version) return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200) # Get the (streamed) chat response from the LLM of choice. - llm_response = generate_chat_response( + llm_response = await agenerate_chat_response( defiltered_query, - meta_log=state.processor_config.conversation.meta_log, - compiled_references=compiled_references, - inferred_queries=inferred_queries, - conversation_command=conversation_command, + meta_log, + compiled_references, + inferred_queries, + conversation_command, + user, ) if llm_response is None: @@ -681,13 +658,14 @@ async def chat( if stream: return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200) + iterator = AsyncIteratorWrapper(llm_response) + # Get the full response from the generator if the stream is not requested. 
aggregated_gpt_response = "" - while True: - try: - aggregated_gpt_response += next(llm_response) - except StopIteration: + async for item in iterator: + if item is None: break + aggregated_gpt_response += item actual_response = aggregated_gpt_response.split("### compiled references:")[0] @@ -708,44 +686,53 @@ async def chat( async def extract_references_and_questions( request: Request, + meta_log: dict, q: str, n: int, conversation_type: ConversationCommand = ConversationCommand.Default, ): user = request.user.object if request.user.is_authenticated else None - # Load Conversation History - meta_log = state.processor_config.conversation.meta_log # Initialize Variables compiled_references: List[Any] = [] inferred_queries: List[str] = [] - if not EmbeddingsAdapters.user_has_embeddings(user=user): + if conversation_type == ConversationCommand.General: + return compiled_references, inferred_queries, q + + if not await EmbeddingsAdapters.user_has_embeddings(user=user): logger.warning( "No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes." ) return compiled_references, inferred_queries, q - if conversation_type == ConversationCommand.General: - return compiled_references, inferred_queries, q - # Extract filter terms from user message defiltered_query = q for filter in [DateFilter(), WordFilter(), FileFilter()]: defiltered_query = filter.defilter(defiltered_query) filters_in_query = q.replace(defiltered_query, "").strip() + using_offline_chat = False + # Infer search queries from user message with timer("Extracting search queries took", logger): # If we've reached here, either the user has enabled offline chat or the openai model is enabled. - if state.processor_config.conversation.offline_chat.enable_offline_chat: - loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model + if await ConversationAdapters.has_offline_chat(user): + using_offline_chat = True + offline_chat = await ConversationAdapters.get_offline_chat(user) + chat_model = offline_chat.chat_model + if state.gpt4all_processor_config is None: + state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model) + + loaded_model = state.gpt4all_processor_config.loaded_model + inferred_queries = extract_questions_offline( defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False ) - elif state.processor_config.conversation.openai_model: - api_key = state.processor_config.conversation.openai_model.api_key - chat_model = state.processor_config.conversation.openai_model.chat_model + elif await ConversationAdapters.has_openai_chat(user): + openai_chat = await ConversationAdapters.get_openai_chat(user) + api_key = openai_chat.api_key + chat_model = openai_chat.chat_model inferred_queries = extract_questions( defiltered_query, model=chat_model, api_key=api_key, conversation_log=meta_log ) @@ -754,7 +741,7 @@ async def extract_references_and_questions( with timer("Searching knowledge base took", logger): result_list = [] for query in inferred_queries: - n_items = min(n, 3) if state.processor_config.conversation.offline_chat.enable_offline_chat else n + n_items = min(n, 3) if using_offline_chat else n result_list.extend( await search( f"{query} {filters_in_query}", @@ -765,6 +752,8 @@ async def extract_references_and_questions( dedupe=False, ) ) + # Dedupe the results again, as duplicates may be returned across queries. 
+ result_list = text_search.deduplicated_search_responses(result_list) compiled_references = [item.additional["compiled"] for item in result_list] return compiled_references, inferred_queries, defiltered_query diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index be9e8700..8a9e53a7 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1,34 +1,50 @@ import logging +import asyncio from datetime import datetime from functools import partial from typing import Iterator, List, Optional, Union +from concurrent.futures import ThreadPoolExecutor from fastapi import HTTPException, Request from khoj.utils import state -from khoj.utils.helpers import ConversationCommand, timer, log_telemetry +from khoj.utils.config import GPT4AllProcessorModel +from khoj.utils.helpers import ConversationCommand, log_telemetry from khoj.processor.conversation.openai.gpt import converse from khoj.processor.conversation.gpt4all.chat_model import converse_offline -from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator +from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator from database.models import KhojUser +from database.adapters import ConversationAdapters logger = logging.getLogger(__name__) +executor = ThreadPoolExecutor(max_workers=1) -def perform_chat_checks(): - if ( - state.processor_config - and state.processor_config.conversation - and ( - state.processor_config.conversation.openai_model - or state.processor_config.conversation.gpt4all_model.loaded_model - ) - ): + +def perform_chat_checks(user: KhojUser): + if ConversationAdapters.has_valid_offline_conversation_config( + user + ) or ConversationAdapters.has_valid_openai_conversation_config(user): return - raise HTTPException( - status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings and restart it." 
- ) + raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.") + + +async def is_ready_to_chat(user: KhojUser): + has_offline_config = await ConversationAdapters.has_offline_chat(user=user) + has_openai_config = await ConversationAdapters.has_openai_chat(user=user) + + if has_offline_config: + offline_chat = await ConversationAdapters.get_offline_chat(user) + chat_model = offline_chat.chat_model + if state.gpt4all_processor_config is None: + state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model) + return True + + ready = has_openai_config or has_offline_config + + if not ready: + raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.") def update_telemetry_state( @@ -74,12 +90,22 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver return ConversationCommand.Default +async def construct_conversation_logs(user: KhojUser): + return (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log + + +async def agenerate_chat_response(*args): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(executor, generate_chat_response, *args) + + def generate_chat_response( q: str, meta_log: dict, compiled_references: List[str] = [], inferred_queries: List[str] = [], conversation_command: ConversationCommand = ConversationCommand.Default, + user: KhojUser = None, ) -> Union[ThreadedGenerator, Iterator[str]]: def _save_to_conversation_log( q: str, @@ -89,17 +115,14 @@ def generate_chat_response( inferred_queries: List[str], meta_log, ): - state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, chat_response]) - state.processor_config.conversation.meta_log["chat"] = message_to_log( + updated_conversation = message_to_log( user_message=q, chat_response=chat_response, user_message_metadata={"created": user_message_time}, khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}}, conversation_log=meta_log.get("chat", []), ) - - # Load Conversation History - meta_log = state.processor_config.conversation.meta_log + ConversationAdapters.save_conversation(user, {"chat": updated_conversation}) # Initialize Variables user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -116,8 +139,14 @@ def generate_chat_response( meta_log=meta_log, ) - if state.processor_config.conversation.offline_chat.enable_offline_chat: - loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model + offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user=user) + conversation_config = ConversationAdapters.get_conversation_config(user) + openai_chat_config = ConversationAdapters.get_openai_conversation_config(user) + if offline_chat_config: + if state.gpt4all_processor_config.loaded_model is None: + state.gpt4all_processor_config = GPT4AllProcessorModel(offline_chat_config.chat_model) + + loaded_model = state.gpt4all_processor_config.loaded_model chat_response = converse_offline( references=compiled_references, user_query=q, @@ -125,14 +154,14 @@ def generate_chat_response( conversation_log=meta_log, completion_func=partial_completion, conversation_command=conversation_command, - model=state.processor_config.conversation.offline_chat.chat_model, - max_prompt_size=state.processor_config.conversation.max_prompt_size, - tokenizer_name=state.processor_config.conversation.tokenizer, + 
model=offline_chat_config.chat_model, + max_prompt_size=conversation_config.max_prompt_size, + tokenizer_name=conversation_config.tokenizer, ) - elif state.processor_config.conversation.openai_model: - api_key = state.processor_config.conversation.openai_model.api_key - chat_model = state.processor_config.conversation.openai_model.chat_model + elif openai_chat_config: + api_key = openai_chat_config.api_key + chat_model = openai_chat_config.chat_model chat_response = converse( compiled_references, q, @@ -141,8 +170,8 @@ def generate_chat_response( api_key=api_key, completion_func=partial_completion, conversation_command=conversation_command, - max_prompt_size=state.processor_config.conversation.max_prompt_size, - tokenizer_name=state.processor_config.conversation.tokenizer, + max_prompt_size=conversation_config.max_prompt_size if conversation_config else None, + tokenizer_name=conversation_config.tokenizer if conversation_config else None, ) except Exception as e: diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py index 1e73c439..1125e653 100644 --- a/src/khoj/routers/indexer.py +++ b/src/khoj/routers/indexer.py @@ -92,7 +92,7 @@ async def update( if dict_to_update is not None: dict_to_update[file.filename] = ( - file.file.read().decode("utf-8") if encoding == "utf-8" else file.file.read() + file.file.read().decode("utf-8") if encoding == "utf-8" else file.file.read() # type: ignore ) else: logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}") @@ -181,24 +181,25 @@ def configure_content( files: Optional[dict[str, dict[str, str]]], search_models: SearchModels, regenerate: bool = False, - t: Optional[Union[state.SearchType, str]] = None, + t: Optional[state.SearchType] = None, full_corpus: bool = True, user: KhojUser = None, ) -> Optional[ContentIndex]: content_index = ContentIndex() - if t in [type.value for type in state.SearchType]: - t = state.SearchType(t).value + if t is not None and not t.value in [type.value for type in state.SearchType]: + logger.warning(f"🚨 Invalid search type: {t}") + return None - assert type(t) == str or t == None, f"Invalid search type: {t}" + search_type = t.value if t else None if files is None: - logger.warning(f"🚨 No files to process for {t} search.") + logger.warning(f"🚨 No files to process for {search_type} search.") return None try: # Initialize Org Notes Search - if (t == None or t == state.SearchType.Org.value) and files["org"]: + if (search_type == None or search_type == state.SearchType.Org.value) and files["org"]: logger.info("🦄 Setting up search for orgmode notes") # Extract Entries, Generate Notes Embeddings text_search.setup( @@ -213,7 +214,7 @@ def configure_content( try: # Initialize Markdown Search - if (t == None or t == state.SearchType.Markdown.value) and files["markdown"]: + if (search_type == None or search_type == state.SearchType.Markdown.value) and files["markdown"]: logger.info("💎 Setting up search for markdown notes") # Extract Entries, Generate Markdown Embeddings text_search.setup( @@ -229,7 +230,7 @@ def configure_content( try: # Initialize PDF Search - if (t == None or t == state.SearchType.Pdf.value) and files["pdf"]: + if (search_type == None or search_type == state.SearchType.Pdf.value) and files["pdf"]: logger.info("🖨️ Setting up search for pdf") # Extract Entries, Generate PDF Embeddings text_search.setup( @@ -245,7 +246,7 @@ def configure_content( try: # Initialize Plaintext Search - if (t == None or t == state.SearchType.Plaintext.value) and 
files["plaintext"]: + if (search_type == None or search_type == state.SearchType.Plaintext.value) and files["plaintext"]: logger.info("📄 Setting up search for plaintext") # Extract Entries, Generate Plaintext Embeddings text_search.setup( @@ -262,7 +263,7 @@ def configure_content( try: # Initialize Image Search if ( - (t == None or t == state.SearchType.Image.value) + (search_type == None or search_type == state.SearchType.Image.value) and content_config and content_config.image and search_models.image_search @@ -278,7 +279,7 @@ def configure_content( try: github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first() - if (t == None or t == state.SearchType.Github.value) and github_config is not None: + if (search_type == None or search_type == state.SearchType.Github.value) and github_config is not None: logger.info("🐙 Setting up search for github") # Extract Entries, Generate Github Embeddings text_search.setup( @@ -296,7 +297,7 @@ def configure_content( try: # Initialize Notion Search notion_config = NotionConfig.objects.filter(user=user).first() - if (t == None or t in state.SearchType.Notion.value) and notion_config: + if (search_type == None or search_type in state.SearchType.Notion.value) and notion_config: logger.info("🔌 Setting up search for notion") text_search.setup( NotionToJsonl, diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 4122c6d0..333d89fa 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -19,7 +19,7 @@ from khoj.utils.rawconfig import ( # Internal Packages from khoj.utils import constants, state -from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config +from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig @@ -83,7 +83,7 @@ if not state.demo: @web_client.get("/config", response_class=HTMLResponse) @requires(["authenticated"], redirect="login_page") def config_page(request: Request): - user = request.user.object if request.user.is_authenticated else None + user = request.user.object enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all()) default_full_config = FullConfig( content_type=None, @@ -100,9 +100,6 @@ if not state.demo: "github": ("github" in enabled_content), "notion": ("notion" in enabled_content), "plaintext": ("plaintext" in enabled_content), - "enable_offline_model": False, - "conversation_openai": False, - "conversation_gpt4all": False, } if state.content_index: @@ -112,13 +109,17 @@ if not state.demo: } ) - if state.processor_config and state.processor_config.conversation: - successfully_configured.update( - { - "conversation_openai": state.processor_config.conversation.openai_model is not None, - "conversation_gpt4all": state.processor_config.conversation.gpt4all_model.loaded_model is not None, - } - ) + enabled_chat_config = ConversationAdapters.get_enabled_conversation_settings(user) + + successfully_configured.update( + { + "conversation_openai": enabled_chat_config["openai"], + "enable_offline_model": enabled_chat_config["offline_chat"], + "conversation_gpt4all": state.gpt4all_processor_config.loaded_model is not None + if state.gpt4all_processor_config + else False, + } + ) return templates.TemplateResponse( "config.html", @@ -127,6 +128,7 @@ if not state.demo: "current_config": current_config, "current_model_state": 
successfully_configured, "anonymous_mode": state.anonymous_mode, + "username": user.username if user else None, }, ) @@ -204,22 +206,22 @@ if not state.demo: ) @web_client.get("/config/processor/conversation/openai", response_class=HTMLResponse) + @requires(["authenticated"], redirect="login_page") def conversation_processor_config_page(request: Request): - default_copy = constants.default_config.copy() - default_processor_config = default_copy["processor"]["conversation"]["openai"] # type: ignore - default_openai_config = OpenAIProcessorConfig( - api_key="", - chat_model=default_processor_config["chat-model"], - ) + user = request.user.object + openai_config = ConversationAdapters.get_openai_conversation_config(user) + + if openai_config: + current_processor_openai_config = OpenAIProcessorConfig( + api_key=openai_config.api_key, + chat_model=openai_config.chat_model, + ) + else: + current_processor_openai_config = OpenAIProcessorConfig( + api_key="", + chat_model="gpt-3.5-turbo", + ) - current_processor_openai_config = ( - state.config.processor.conversation.openai - if state.config - and state.config.processor - and state.config.processor.conversation - and state.config.processor.conversation.openai - else default_openai_config - ) current_processor_openai_config = json.loads(current_processor_openai_config.json()) return templates.TemplateResponse( diff --git a/src/khoj/search_type/image_search.py b/src/khoj/search_type/image_search.py index 8b92d9db..d7f486af 100644 --- a/src/khoj/search_type/image_search.py +++ b/src/khoj/search_type/image_search.py @@ -236,6 +236,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count= "image_score": f"{hit['image_score']:.9f}", "metadata_score": f"{hit['metadata_score']:.9f}", }, + "corpus_id": hit["corpus_id"], } ) ] diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 36d6a791..dc6593f5 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -14,10 +14,9 @@ from asgiref.sync import sync_to_async # Internal Packages from khoj.utils import state from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer -from khoj.utils.config import TextSearchModel from khoj.utils.models import BaseEncoder from khoj.utils.state import SearchType -from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, Entry +from khoj.utils.rawconfig import SearchResponse, Entry from khoj.utils.jsonl import load_jsonl from khoj.processor.text_to_jsonl import TextEmbeddings from database.adapters import EmbeddingsAdapters @@ -36,36 +35,6 @@ search_type_to_embeddings_type = { } -def initialize_model(search_config: TextSearchConfig): - "Initialize model for semantic search on text" - torch.set_num_threads(4) - - # If model directory is configured - if search_config.model_directory: - # Convert model directory to absolute path - search_config.model_directory = resolve_absolute_path(search_config.model_directory) - # Create model directory if it doesn't exist - search_config.model_directory.parent.mkdir(parents=True, exist_ok=True) - - # The bi-encoder encodes all entries to use for semantic search - bi_encoder = load_model( - model_dir=search_config.model_directory, - model_name=search_config.encoder, - model_type=search_config.encoder_type or SentenceTransformer, - device=f"{state.device}", - ) - - # The cross-encoder re-ranks the results to improve quality - cross_encoder = load_model( - model_dir=search_config.model_directory, - 
model_name=search_config.cross_encoder, - model_type=CrossEncoder, - device=f"{state.device}", - ) - - return TextSearchModel(bi_encoder, cross_encoder) - - def extract_entries(jsonl_file) -> List[Entry]: "Load entries from compressed jsonl" return list(map(Entry.from_dict, load_jsonl(jsonl_file))) @@ -176,6 +145,7 @@ def collate_results(hits, dedupe=True): { "entry": hit.raw, "score": hit.distance, + "corpus_id": str(hit.corpus_id), "additional": { "file": hit.file_path, "compiled": hit.compiled, @@ -185,6 +155,28 @@ def collate_results(hits, dedupe=True): ) +def deduplicated_search_responses(hits: List[SearchResponse]): + hit_ids = set() + for hit in hits: + if hit.corpus_id in hit_ids: + continue + + else: + hit_ids.add(hit.corpus_id) + yield SearchResponse.parse_obj( + { + "entry": hit.entry, + "score": hit.score, + "corpus_id": hit.corpus_id, + "additional": { + "file": hit.additional["file"], + "compiled": hit.additional["compiled"], + "heading": hit.additional["heading"], + }, + } + ) + + def rerank_and_sort_results(hits, query): # Score all retrieved entries using the cross-encoder hits = cross_encoder_score(query, hits) diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index ee5b4f9f..3c084c4f 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -5,8 +5,7 @@ from enum import Enum import logging from dataclasses import dataclass -from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any +from typing import TYPE_CHECKING, List, Optional, Union, Any from khoj.processor.conversation.gpt4all.utils import download_model # External Packages @@ -19,9 +18,7 @@ logger = logging.getLogger(__name__) # Internal Packages if TYPE_CHECKING: from sentence_transformers import CrossEncoder - from khoj.search_filter.base_filter import BaseFilter from khoj.utils.models import BaseEncoder - from khoj.utils.rawconfig import ConversationProcessorConfig, Entry, OpenAIProcessorConfig class SearchType(str, Enum): @@ -79,31 +76,15 @@ class GPT4AllProcessorConfig: loaded_model: Union[Any, None] = None -class ConversationProcessorConfigModel: +class GPT4AllProcessorModel: def __init__( self, - conversation_config: ConversationProcessorConfig, + chat_model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin", ): - self.openai_model = conversation_config.openai - self.gpt4all_model = GPT4AllProcessorConfig() - self.offline_chat = conversation_config.offline_chat or OfflineChatProcessorConfig() - self.max_prompt_size = conversation_config.max_prompt_size - self.tokenizer = conversation_config.tokenizer - self.conversation_logfile = Path(conversation_config.conversation_logfile) - self.chat_session: List[str] = [] - self.meta_log: dict = {} - - if self.offline_chat.enable_offline_chat: - try: - self.gpt4all_model.loaded_model = download_model(self.offline_chat.chat_model) - except Exception as e: - self.offline_chat.enable_offline_chat = False - self.gpt4all_model.loaded_model = None - logger.error(f"Error while loading offline chat model: {e}", exc_info=True) - else: - self.gpt4all_model.loaded_model = None - - -@dataclass -class ProcessorConfigModel: - conversation: Union[ConversationProcessorConfigModel, None] = None + self.chat_model = chat_model + self.loaded_model = None + try: + self.loaded_model = download_model(self.chat_model) + except ValueError as e: + self.loaded_model = None + logger.error(f"Error while loading offline chat model: {e}", exc_info=True) diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py index 
181dee04..e9d431c6 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -8,136 +8,14 @@ telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry" content_directory = "~/.khoj/content/" empty_config = { - "content-type": { - "org": { - "input-files": None, - "input-filter": None, - "compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz", - "embeddings-file": "~/.khoj/content/org/org_embeddings.pt", - "index-heading-entries": False, - }, - "markdown": { - "input-files": None, - "input-filter": None, - "compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz", - "embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt", - }, - "pdf": { - "input-files": None, - "input-filter": None, - "compressed-jsonl": "~/.khoj/content/pdf/pdf.jsonl.gz", - "embeddings-file": "~/.khoj/content/pdf/pdf_embeddings.pt", - }, - "plaintext": { - "input-files": None, - "input-filter": None, - "compressed-jsonl": "~/.khoj/content/plaintext/plaintext.jsonl.gz", - "embeddings-file": "~/.khoj/content/plaintext/plaintext_embeddings.pt", - }, - }, "search-type": { - "symmetric": { - "encoder": "sentence-transformers/all-MiniLM-L6-v2", - "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2", - "model_directory": "~/.khoj/search/symmetric/", - }, - "asymmetric": { - "encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1", - "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2", - "model_directory": "~/.khoj/search/asymmetric/", - }, "image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"}, }, - "processor": { - "conversation": { - "openai": { - "api-key": None, - "chat-model": "gpt-3.5-turbo", - }, - "offline-chat": { - "enable-offline-chat": False, - "chat-model": "llama-2-7b-chat.ggmlv3.q4_0.bin", - }, - "tokenizer": None, - "max-prompt-size": None, - "conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json", - } - }, } # default app config to use default_config = { - "content-type": { - "org": { - "input-files": None, - "input-filter": None, - "compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz", - "embeddings-file": "~/.khoj/content/org/org_embeddings.pt", - "index-heading-entries": False, - }, - "markdown": { - "input-files": None, - "input-filter": None, - "compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz", - "embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt", - }, - "pdf": { - "input-files": None, - "input-filter": None, - "compressed-jsonl": "~/.khoj/content/pdf/pdf.jsonl.gz", - "embeddings-file": "~/.khoj/content/pdf/pdf_embeddings.pt", - }, - "image": { - "input-directories": None, - "input-filter": None, - "embeddings-file": "~/.khoj/content/image/image_embeddings.pt", - "batch-size": 50, - "use-xmp-metadata": False, - }, - "github": { - "pat-token": None, - "repos": [], - "compressed-jsonl": "~/.khoj/content/github/github.jsonl.gz", - "embeddings-file": "~/.khoj/content/github/github_embeddings.pt", - }, - "notion": { - "token": None, - "compressed-jsonl": "~/.khoj/content/notion/notion.jsonl.gz", - "embeddings-file": "~/.khoj/content/notion/notion_embeddings.pt", - }, - "plaintext": { - "input-files": None, - "input-filter": None, - "compressed-jsonl": "~/.khoj/content/plaintext/plaintext.jsonl.gz", - "embeddings-file": "~/.khoj/content/plaintext/plaintext_embeddings.pt", - }, - }, "search-type": { - "symmetric": { - "encoder": "sentence-transformers/all-MiniLM-L6-v2", - "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2", - "model_directory": 
"~/.khoj/search/symmetric/", - }, - "asymmetric": { - "encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1", - "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2", - "model_directory": "~/.khoj/search/asymmetric/", - }, "image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"}, }, - "processor": { - "conversation": { - "openai": { - "api-key": None, - "chat-model": "gpt-3.5-turbo", - }, - "offline-chat": { - "enable-offline-chat": False, - "chat-model": "llama-2-7b-chat.ggmlv3.q4_0.bin", - }, - "tokenizer": None, - "max-prompt-size": None, - "conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json", - } - }, } diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index e41791f9..0269a9e9 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -15,6 +15,7 @@ from time import perf_counter import torch from typing import Optional, Union, TYPE_CHECKING import uuid +from asgiref.sync import sync_to_async # Internal Packages from khoj.utils import constants @@ -29,6 +30,28 @@ if TYPE_CHECKING: from khoj.utils.rawconfig import AppConfig +class AsyncIteratorWrapper: + def __init__(self, obj): + self._it = iter(obj) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + value = await self.next_async() + except StopAsyncIteration: + return + return value + + @sync_to_async + def next_async(self): + try: + return next(self._it) + except StopIteration: + raise StopAsyncIteration + + def is_none_or_empty(item): return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == "" diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index 5d2b3ce4..a469951f 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -67,13 +67,6 @@ class ContentConfig(ConfigBase): notion: Optional[NotionContentConfig] -class TextSearchConfig(ConfigBase): - encoder: str - cross_encoder: str - encoder_type: Optional[str] - model_directory: Optional[Path] - - class ImageSearchConfig(ConfigBase): encoder: str encoder_type: Optional[str] @@ -81,8 +74,6 @@ class ImageSearchConfig(ConfigBase): class SearchConfig(ConfigBase): - asymmetric: Optional[TextSearchConfig] - symmetric: Optional[TextSearchConfig] image: Optional[ImageSearchConfig] @@ -97,11 +88,10 @@ class OfflineChatProcessorConfig(ConfigBase): class ConversationProcessorConfig(ConfigBase): - conversation_logfile: Path - openai: Optional[OpenAIProcessorConfig] - offline_chat: Optional[OfflineChatProcessorConfig] - max_prompt_size: Optional[int] - tokenizer: Optional[str] + openai: Optional[OpenAIProcessorConfig] = None + offline_chat: Optional[OfflineChatProcessorConfig] = None + max_prompt_size: Optional[int] = None + tokenizer: Optional[str] = None class ProcessorConfig(ConfigBase): @@ -125,6 +115,7 @@ class SearchResponse(ConfigBase): score: float cross_score: Optional[float] additional: Optional[dict] + corpus_id: str class Entry: diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index d6169d2a..40806c51 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -10,7 +10,7 @@ from pathlib import Path # Internal Packages from khoj.utils import config as utils_config -from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel +from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel from khoj.utils.helpers import LRU from khoj.utils.rawconfig import FullConfig from khoj.processor.embeddings import EmbeddingsModel, 
CrossEncoderModel @@ -21,7 +21,7 @@ search_models = SearchModels() embeddings_model = EmbeddingsModel() cross_encoder_model = CrossEncoderModel() content_index = ContentIndex() -processor_config = ProcessorConfigModel() +gpt4all_processor_config: GPT4AllProcessorModel = None config_file: Path = None verbose: int = 0 host: str = None diff --git a/tests/conftest.py b/tests/conftest.py index ee4b9e57..12ac4f7b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ from pathlib import Path import pytest from fastapi.staticfiles import StaticFiles from fastapi import FastAPI -import factory import os from fastapi import FastAPI @@ -13,7 +12,7 @@ app = FastAPI() # Internal Packages -from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware +from khoj.configure import configure_routes, configure_search_types, configure_middleware from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl from khoj.search_type import image_search, text_search from khoj.utils.config import SearchModels @@ -21,13 +20,8 @@ from khoj.utils.constants import web_directory from khoj.utils.helpers import resolve_absolute_path from khoj.utils.rawconfig import ( ContentConfig, - ConversationProcessorConfig, - OfflineChatProcessorConfig, - OpenAIProcessorConfig, - ProcessorConfig, ImageContentConfig, SearchConfig, - TextSearchConfig, ImageSearchConfig, ) from khoj.utils import state, fs_syncer @@ -42,42 +36,25 @@ from database.models import ( GithubRepoConfig, ) +from tests.helpers import ( + UserFactory, + ConversationProcessorConfigFactory, + OpenAIProcessorConversationConfigFactory, + OfflineChatProcessorConversationConfigFactory, +) + @pytest.fixture(autouse=True) def enable_db_access_for_all_tests(db): pass -class UserFactory(factory.django.DjangoModelFactory): - class Meta: - model = KhojUser - - username = factory.Faker("name") - email = factory.Faker("email") - password = factory.Faker("password") - uuid = factory.Faker("uuid4") - - @pytest.fixture(scope="session") def search_config() -> SearchConfig: model_dir = resolve_absolute_path("~/.khoj/search") model_dir.mkdir(parents=True, exist_ok=True) search_config = SearchConfig() - search_config.symmetric = TextSearchConfig( - encoder="sentence-transformers/all-MiniLM-L6-v2", - cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2", - model_directory=model_dir / "symmetric/", - encoder_type=None, - ) - - search_config.asymmetric = TextSearchConfig( - encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1", - cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2", - model_directory=model_dir / "asymmetric/", - encoder_type=None, - ) - search_config.image = ImageSearchConfig( encoder="sentence-transformers/clip-ViT-B-32", model_directory=model_dir / "image/", @@ -177,55 +154,48 @@ def md_content_config(): return markdown_config -@pytest.fixture(scope="session") -def processor_config(tmp_path_factory): - openai_api_key = os.getenv("OPENAI_API_KEY") - processor_dir = tmp_path_factory.mktemp("processor") - - # The conversation processor is the only configured processor - # It needs an OpenAI API key to work. 
- if not openai_api_key: - return - - # Setup conversation processor, if OpenAI API key is set - processor_config = ProcessorConfig() - processor_config.conversation = ConversationProcessorConfig( - openai=OpenAIProcessorConfig(api_key=openai_api_key), - conversation_logfile=processor_dir.joinpath("conversation_logs.json"), - ) - - return processor_config - - -@pytest.fixture(scope="session") -def processor_config_offline_chat(tmp_path_factory): - processor_dir = tmp_path_factory.mktemp("processor") - - # Setup conversation processor - processor_config = ProcessorConfig() - offline_chat = OfflineChatProcessorConfig(enable_offline_chat=True) - processor_config.conversation = ConversationProcessorConfig( - offline_chat=offline_chat, - conversation_logfile=processor_dir.joinpath("conversation_logs.json"), - ) - - return processor_config - - -@pytest.fixture(scope="session") -def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig): +@pytest.fixture(scope="function") +def chat_client(search_config: SearchConfig, default_user2: KhojUser): # Initialize app state state.config.search_type = search_config state.SearchType = configure_search_types(state.config) + LocalMarkdownConfig.objects.create( + input_files=None, + input_filter=["tests/data/markdown/*.markdown"], + user=default_user2, + ) + # Index Markdown Content for Search - all_files = fs_syncer.collect_files() + all_files = fs_syncer.collect_files(user=default_user2) state.content_index = configure_content( - state.content_index, state.config.content_type, all_files, state.search_models + state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2 ) # Initialize Processor from Config - state.processor_config = configure_processor(processor_config) + if os.getenv("OPENAI_API_KEY"): + OpenAIProcessorConversationConfigFactory(user=default_user2) + + state.anonymous_mode = True + + app = FastAPI() + + configure_routes(app) + configure_middleware(app) + app.mount("/static", StaticFiles(directory=web_directory), name="static") + return TestClient(app) + + +@pytest.fixture(scope="function") +def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser): + # Initialize app state + state.config.search_type = search_config + state.SearchType = configure_search_types(state.config) + + # Initialize Processor from Config + if os.getenv("OPENAI_API_KEY"): + OpenAIProcessorConversationConfigFactory(user=default_user2) + state.anonymous_mode = True app = FastAPI() @@ -249,7 +219,6 @@ def fastapi_app(): def client( content_config: ContentConfig, search_config: SearchConfig, - processor_config: ProcessorConfig, default_user: KhojUser, ): state.config.content_type = content_config @@ -274,7 +243,7 @@ def client( user=default_user, ) - state.processor_config = configure_processor(processor_config) + ConversationProcessorConfigFactory(user=default_user) state.anonymous_mode = True configure_routes(app) @@ -286,25 +255,32 @@ def client( @pytest.fixture(scope="function") def client_offline_chat( search_config: SearchConfig, - processor_config_offline_chat: ProcessorConfig, content_config: ContentConfig, - md_content_config, + default_user2: KhojUser, ): # Initialize app state state.config.content_type = md_content_config state.config.search_type = search_config state.SearchType = configure_search_types(state.config) + LocalMarkdownConfig.objects.create( + input_files=None, + input_filter=["tests/data/markdown/*.markdown"], + user=default_user2, + 
) + # Index Markdown Content for Search state.search_models.image_search = image_search.initialize_model(search_config.image) - all_files = fs_syncer.collect_files(state.config.content_type) - state.content_index = configure_content( - state.content_index, state.config.content_type, all_files, state.search_models + all_files = fs_syncer.collect_files(user=default_user2) + configure_content( + state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2 ) # Initialize Processor from Config - state.processor_config = configure_processor(processor_config_offline_chat) + ConversationProcessorConfigFactory(user=default_user2) + OfflineChatProcessorConversationConfigFactory(user=default_user2) + state.anonymous_mode = True configure_routes(app) diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 00000000..655c4435 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,51 @@ +import factory +import os + +from database.models import ( + KhojUser, + ConversationProcessorConfig, + OfflineChatProcessorConversationConfig, + OpenAIProcessorConversationConfig, + Conversation, +) + + +class UserFactory(factory.django.DjangoModelFactory): + class Meta: + model = KhojUser + + username = factory.Faker("name") + email = factory.Faker("email") + password = factory.Faker("password") + uuid = factory.Faker("uuid4") + + +class ConversationProcessorConfigFactory(factory.django.DjangoModelFactory): + class Meta: + model = ConversationProcessorConfig + + max_prompt_size = 2000 + tokenizer = None + + +class OfflineChatProcessorConversationConfigFactory(factory.django.DjangoModelFactory): + class Meta: + model = OfflineChatProcessorConversationConfig + + enable_offline_chat = True + chat_model = "llama-2-7b-chat.ggmlv3.q4_0.bin" + + +class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory): + class Meta: + model = OpenAIProcessorConversationConfig + + api_key = os.getenv("OPENAI_API_KEY") + chat_model = "gpt-3.5-turbo" + + +class ConversationFactory(factory.django.DjangoModelFactory): + class Meta: + model = Conversation + + user = factory.SubFactory(UserFactory) diff --git a/tests/test_client.py b/tests/test_client.py index b77ba07d..1a6b1346 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -119,7 +119,12 @@ def test_get_configured_types_via_api(client, sample_org_data): def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser): # Arrange text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2) + + # Act response = client.get(f"/api/config/types") + + # Assert + assert response.status_code == 200 assert response.json() == ["all", "org", "image"] diff --git a/tests/test_gpt4all_chat_director.py b/tests/test_gpt4all_chat_director.py index 3e72a7e2..d978fc99 100644 --- a/tests/test_gpt4all_chat_director.py +++ b/tests/test_gpt4all_chat_director.py @@ -9,8 +9,7 @@ from faker import Faker # Internal Packages from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import message_to_log -from khoj.utils import state - +from tests.helpers import ConversationFactory SKIP_TESTS = True pytestmark = pytest.mark.skipif( @@ -23,7 +22,7 @@ fake = Faker() # Helpers # ---------------------------------------------------------------------------------------------------- -def populate_chat_history(message_list): +def populate_chat_history(message_list, user): # Generate conversation logs conversation_log = {"chat": []} for 
user_message, llm_message, context in message_list: @@ -33,14 +32,15 @@ def populate_chat_history(message_list): {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, ) - # Update Conversation Metadata Logs in Application State - state.processor_config.conversation.meta_log = conversation_log + # Update Conversation Metadata Logs in Database + ConversationFactory(user=user, conversation_log=conversation_log) # Tests # ---------------------------------------------------------------------------------------------------- @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet") @pytest.mark.chatquality +@pytest.mark.django_db(transaction=True) def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_chat): # Act response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true') @@ -56,13 +56,14 @@ def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_c # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_answer_from_chat_history(client_offline_chat): +@pytest.mark.django_db(transaction=True) +def test_answer_from_chat_history(client_offline_chat, default_user2): # Arrange message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true') @@ -78,7 +79,8 @@ def test_answer_from_chat_history(client_offline_chat): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_answer_from_currently_retrieved_content(client_offline_chat): +@pytest.mark.django_db(transaction=True) +def test_answer_from_currently_retrieved_content(client_offline_chat, default_user2): # Arrange message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), @@ -88,7 +90,7 @@ def test_answer_from_currently_retrieved_content(client_offline_chat): ["Testatron was born on 1st April 1984 in Testville."], ), ] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = client_offline_chat.get(f'/api/chat?q="Where was Xi Li born?"') @@ -101,7 +103,8 @@ def test_answer_from_currently_retrieved_content(client_offline_chat): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat): +@pytest.mark.django_db(transaction=True) +def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat, default_user2): # Arrange message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), @@ -111,7 +114,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin ["Testatron was born on 1st April 1984 in Testville."], ), ] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = client_offline_chat.get(f'/api/chat?q="Where was I born?"') @@ -130,13 +133,14 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin reason="Chat director not capable of answering this question yet because it requires extract_questions", ) @pytest.mark.chatquality -def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat): +@pytest.mark.django_db(transaction=True) +def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat, default_user2): # Arrange message_list = [ ("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = client_offline_chat.get(f'/api/chat?q="Where was I born?"') @@ -154,14 +158,15 @@ def test_answer_from_chat_history_and_currently_retrieved_content(client_offline # ---------------------------------------------------------------------------------------------------- @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet") @pytest.mark.chatquality -def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat): +@pytest.mark.django_db(transaction=True) +def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, default_user2): "Chat director should say don't know as not enough contexts in chat history or retrieved to answer question" # Arrange message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = client_offline_chat.get(f'/api/chat?q="Where was I born?"&stream=true') @@ -177,11 +182,12 @@ def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_answer_using_general_command(client_offline_chat): +@pytest.mark.django_db(transaction=True) +def test_answer_using_general_command(client_offline_chat, default_user2): # Arrange query = urllib.parse.quote("/general Where was Xi Li born?") message_list = [] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = client_offline_chat.get(f"/api/chat?q={query}&stream=true") @@ -194,11 +200,12 @@ def test_answer_using_general_command(client_offline_chat): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_answer_from_retrieved_content_using_notes_command(client_offline_chat): +@pytest.mark.django_db(transaction=True) +def test_answer_from_retrieved_content_using_notes_command(client_offline_chat, default_user2): # Arrange query = urllib.parse.quote("/notes Where was Xi Li born?") message_list = [] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = client_offline_chat.get(f"/api/chat?q={query}&stream=true") @@ -211,12 +218,13 @@ def test_answer_from_retrieved_content_using_notes_command(client_offline_chat): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_answer_using_file_filter(client_offline_chat): +@pytest.mark.django_db(transaction=True) +def test_answer_using_file_filter(client_offline_chat, default_user2): # Arrange no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"') answer_query = urllib.parse.quote('Where was Xi Li born? 
file:"Xi Li.markdown"') message_list = [] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act no_answer_response = client_offline_chat.get(f"/api/chat?q={no_answer_query}&stream=true").content.decode("utf-8") @@ -229,11 +237,12 @@ def test_answer_using_file_filter(client_offline_chat): # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality -def test_answer_not_known_using_notes_command(client_offline_chat): +@pytest.mark.django_db(transaction=True) +def test_answer_not_known_using_notes_command(client_offline_chat, default_user2): # Arrange query = urllib.parse.quote("/notes Where was Testatron born?") message_list = [] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = client_offline_chat.get(f"/api/chat?q={query}&stream=true") @@ -247,6 +256,7 @@ def test_answer_not_known_using_notes_command(client_offline_chat): # ---------------------------------------------------------------------------------------------------- @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet") @pytest.mark.chatquality +@pytest.mark.django_db(transaction=True) @freeze_time("2023-04-01") def test_answer_requires_current_date_awareness(client_offline_chat): "Chat actor should be able to answer questions relative to current date using provided notes" @@ -265,6 +275,7 @@ def test_answer_requires_current_date_awareness(client_offline_chat): # ---------------------------------------------------------------------------------------------------- @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet") @pytest.mark.chatquality +@pytest.mark.django_db(transaction=True) @freeze_time("2023-04-01") def test_answer_requires_date_aware_aggregation_across_provided_notes(client_offline_chat): "Chat director should be able to answer questions that require date aware aggregation across multiple notes" @@ -280,14 +291,15 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(client_off # ---------------------------------------------------------------------------------------------------- @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet") @pytest.mark.chatquality -def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat): +@pytest.mark.django_db(transaction=True) +def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat, default_user2): # Arrange message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ("Where was I born?", "You were born Testville.", []), ] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = client_offline_chat.get( @@ -307,7 +319,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(client # ---------------------------------------------------------------------------------------------------- @pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.") @pytest.mark.chatquality -def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat): +@pytest.mark.django_db(transaction=True) +def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat, default_user2): # Act response = client_offline_chat.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true') response_message = response.content.decode("utf-8") @@ -328,14 +341,15 @@ def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_ # ---------------------------------------------------------------------------------------------------- @pytest.mark.xfail(reason="Chat director not capable of answering this question yet") @pytest.mark.chatquality -def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat): +@pytest.mark.django_db(transaction=True) +def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2): # Arrange message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ("Where was I born?", "You were born Testville.", []), ] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true') @@ -350,11 +364,12 @@ def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat): @pytest.mark.chatquality -def test_answer_chat_history_very_long(client_offline_chat): +@pytest.mark.django_db(transaction=True) +def test_answer_chat_history_very_long(client_offline_chat, default_user2): # Arrange message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true') @@ -368,6 +383,7 @@ def test_answer_chat_history_very_long(client_offline_chat): # ---------------------------------------------------------------------------------------------------- @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet") @pytest.mark.chatquality +@pytest.mark.django_db(transaction=True) def test_answer_requires_multiple_independent_searches(client_offline_chat): "Chat director should be able to answer by doing multiple independent searches for required information" # Act diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py index abbd1831..14a73f15 100644 --- a/tests/test_openai_chat_director.py +++ b/tests/test_openai_chat_director.py @@ -9,8 +9,8 @@ from khoj.processor.conversation import prompts # Internal Packages from khoj.processor.conversation.utils import message_to_log -from khoj.utils import state - +from tests.helpers import ConversationFactory +from database.models import 
KhojUser # Initialize variables for tests api_key = os.getenv("OPENAI_API_KEY") @@ -23,7 +23,7 @@ if api_key is None: # Helpers # ---------------------------------------------------------------------------------------------------- -def populate_chat_history(message_list): +def populate_chat_history(message_list, user=None): # Generate conversation logs conversation_log = {"chat": []} for user_message, gpt_message, context in message_list: @@ -33,13 +33,14 @@ def populate_chat_history(message_list): {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, ) - # Update Conversation Metadata Logs in Application State - state.processor_config.conversation.meta_log = conversation_log + # Update Conversation Metadata Logs in Database + ConversationFactory(user=user, conversation_log=conversation_log) # Tests # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality +@pytest.mark.django_db(transaction=True) def test_chat_with_no_chat_history_or_retrieved_content(chat_client): # Act response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true') @@ -54,14 +55,15 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client): # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) @pytest.mark.chatquality -def test_answer_from_chat_history(chat_client): +def test_answer_from_chat_history(chat_client, default_user2: KhojUser): # Arrange message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true') @@ -76,8 +78,9 @@ def test_answer_from_chat_history(chat_client): # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) @pytest.mark.chatquality -def test_answer_from_currently_retrieved_content(chat_client): +def test_answer_from_currently_retrieved_content(chat_client, default_user2: KhojUser): # Arrange message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), @@ -87,7 +90,7 @@ def test_answer_from_currently_retrieved_content(chat_client): ["Testatron was born on 1st April 1984 in Testville."], ), ] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"') @@ -99,8 +102,9 @@ def test_answer_from_currently_retrieved_content(chat_client): # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) @pytest.mark.chatquality -def test_answer_from_chat_history_and_previously_retrieved_content(chat_client): +def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_no_background, default_user2: KhojUser): # Arrange message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), @@ -110,10 +114,10 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client): ["Testatron was born on 1st April 1984 in Testville."], ), ] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act - response = chat_client.get(f'/api/chat?q="Where was I born?"') + response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"') response_message = response.content.decode("utf-8") # Assert @@ -125,14 +129,15 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client): # ---------------------------------------------------------------------------------------------------- @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet") +@pytest.mark.django_db(transaction=True) @pytest.mark.chatquality -def test_answer_from_chat_history_and_currently_retrieved_content(chat_client): +def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, default_user2: KhojUser): # Arrange message_list = [ ("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = chat_client.get(f'/api/chat?q="Where was I born?"') @@ -148,15 +153,16 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client): # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) @pytest.mark.chatquality -def test_no_answer_in_chat_history_or_retrieved_content(chat_client): +def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_user2: KhojUser): "Chat director should say don't know as not enough contexts in chat history or retrieved to answer question" # Arrange message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true') @@ -171,12 +177,13 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client): # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) @pytest.mark.chatquality -def test_answer_using_general_command(chat_client): +def test_answer_using_general_command(chat_client, default_user2: KhojUser): # Arrange query = urllib.parse.quote("/general Where was Xi Li born?") message_list = [] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = chat_client.get(f"/api/chat?q={query}&stream=true") @@ -188,12 +195,13 @@ def test_answer_using_general_command(chat_client): # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) @pytest.mark.chatquality -def test_answer_from_retrieved_content_using_notes_command(chat_client): +def test_answer_from_retrieved_content_using_notes_command(chat_client, default_user2: KhojUser): # Arrange query = urllib.parse.quote("/notes Where was Xi Li born?") message_list = [] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = chat_client.get(f"/api/chat?q={query}&stream=true") @@ -205,15 +213,16 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client): # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) @pytest.mark.chatquality -def test_answer_not_known_using_notes_command(chat_client): +def test_answer_not_known_using_notes_command(chat_client_no_background, default_user2: KhojUser): # Arrange query = urllib.parse.quote("/notes Where was Testatron born?") message_list = [] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act - response = chat_client.get(f"/api/chat?q={query}&stream=true") + response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true") response_message = response.content.decode("utf-8") # Assert @@ -223,6 +232,7 @@ def test_answer_not_known_using_notes_command(chat_client): # ---------------------------------------------------------------------------------------------------- @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet") +@pytest.mark.django_db(transaction=True) @pytest.mark.chatquality @freeze_time("2023-04-01") def test_answer_requires_current_date_awareness(chat_client): @@ -240,11 +250,13 @@ def test_answer_requires_current_date_awareness(chat_client): # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) @pytest.mark.chatquality @freeze_time("2023-04-01") def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client): "Chat director should be able to answer questions that require date aware aggregation across multiple notes" # Act + response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true') response_message = response.content.decode("utf-8") @@ -254,15 +266,16 @@ def 
test_answer_requires_date_aware_aggregation_across_provided_notes(chat_clien # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) @pytest.mark.chatquality -def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_client): +def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_client, default_user2: KhojUser): # Arrange message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ("Where was I born?", "You were born Testville.", []), ] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = chat_client.get( @@ -280,10 +293,12 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) @pytest.mark.chatquality -def test_ask_for_clarification_if_not_enough_context_in_question(chat_client): +def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background): # Act - response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true') + + response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true') response_message = response.content.decode("utf-8") # Assert @@ -301,15 +316,16 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client): # ---------------------------------------------------------------------------------------------------- @pytest.mark.xfail(reason="Chat director not capable of answering this question yet") +@pytest.mark.django_db(transaction=True) @pytest.mark.chatquality -def test_answer_in_chat_history_beyond_lookback_window(chat_client): +def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user2: KhojUser): # Arrange message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ("Where was I born?", "You were born Testville.", []), ] - populate_chat_history(message_list) + populate_chat_history(message_list, default_user2) # Act response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true') @@ -324,6 +340,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client): # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) @pytest.mark.chatquality def test_answer_requires_multiple_independent_searches(chat_client): "Chat director should be able to answer by doing multiple independent searches for required information" @@ -340,10 +357,12 @@ def test_answer_requires_multiple_independent_searches(chat_client): # ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) def test_answer_using_file_filter(chat_client): "Chat should be able to use search filters in the query" # Act query = urllib.parse.quote('Is Xi older than Namita? 
file:"Namita.markdown" file:"Xi Li.markdown"') + response = chat_client.get(f"/api/chat?q={query}&stream=true") response_message = response.content.decode("utf-8") diff --git a/tests/test_text_search.py b/tests/test_text_search.py index af47ffe5..ec8034ef 100644 --- a/tests/test_text_search.py +++ b/tests/test_text_search.py @@ -13,12 +13,11 @@ from khoj.search_type import text_search from khoj.utils.rawconfig import ContentConfig, SearchConfig from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl from khoj.processor.github.github_to_jsonl import GithubToJsonl -from khoj.utils.config import SearchModels -from khoj.utils.fs_syncer import get_org_files, collect_files +from khoj.utils.fs_syncer import collect_files, get_org_files from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig logger = logging.getLogger(__name__) -from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig +from khoj.utils.rawconfig import ContentConfig, SearchConfig # Test