[Multi-User Part 3]: Separate chat sessions based on authenticated users (#511)

- Add a data model that allows us to store Conversations per user. This is a minimal lift over the current setup, where the underlying data is stored in a JSON file, and maintains parity with that configuration.
- There does _seem_ to be some regression in chat quality, most likely attributable to changes in the search results.

This will help with #275. It should now be much easier to maintain multiple Conversations in a single backend table. The UI will need some further thought.
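
As a rough sketch of what this enables (the model and field names come from the diff below; the user lookup is illustrative):

```python
# Hypothetical usage: one Conversation row per authenticated user, with the
# chat log stored as JSON in the same shape as the old logfile.
from database.models import Conversation, KhojUser

user = KhojUser.objects.get(email="alice@example.com")  # illustrative user
conversation, _ = Conversation.objects.get_or_create(user=user)
conversation.conversation_log = {"chat": []}  # list of chat messages, as before
conversation.save()
```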
sabaimran 2023-10-26 11:37:41 -07:00 committed by GitHub
parent a8a82d274a
commit 4b6ec248a6
24 changed files with 719 additions and 626 deletions

View file

@ -2,3 +2,5 @@
DJANGO_SETTINGS_MODULE = app.settings
pythonpath = . src
testpaths = tests
markers =
chatquality: marks tests as chatquality (deselect with '-m "not chatquality"')
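
A minimal sketch of how a test opts into the new marker (the test name is illustrative):

```python
import pytest

@pytest.mark.chatquality  # deselect with: pytest -m "not chatquality"
def test_chat_quality_example():
    ...
```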

View file

@ -1,5 +1,4 @@
from typing import Type, TypeVar, List
import uuid
from datetime import date
from django.db import models
@ -21,6 +20,13 @@ from database.models import (
GithubConfig,
Embeddings,
GithubRepoConfig,
Conversation,
ConversationProcessorConfig,
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
)
from khoj.utils.rawconfig import (
ConversationProcessorConfig as UserConversationProcessorConfig,
)
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
@ -54,18 +60,17 @@ async def get_or_create_user(token: dict) -> KhojUser:
async def create_google_user(token: dict) -> KhojUser:
user_info = token.get("userinfo")
user = await KhojUser.objects.acreate(username=user_info.get("email"), email=user_info.get("email"))
user = await KhojUser.objects.acreate(username=token.get("email"), email=token.get("email"))
await user.asave()
await GoogleUser.objects.acreate(
sub=user_info.get("sub"),
azp=user_info.get("azp"),
email=user_info.get("email"),
name=user_info.get("name"),
given_name=user_info.get("given_name"),
family_name=user_info.get("family_name"),
picture=user_info.get("picture"),
locale=user_info.get("locale"),
sub=token.get("sub"),
azp=token.get("azp"),
email=token.get("email"),
name=token.get("name"),
given_name=token.get("given_name"),
family_name=token.get("family_name"),
picture=token.get("picture"),
locale=token.get("locale"),
user=user,
)
@ -137,6 +142,124 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
return config
class ConversationAdapters:
@staticmethod
def get_conversation_by_user(user: KhojUser):
conversation = Conversation.objects.filter(user=user)
if conversation.exists():
return conversation.first()
return Conversation.objects.create(user=user)
@staticmethod
async def aget_conversation_by_user(user: KhojUser):
conversation = Conversation.objects.filter(user=user)
if await conversation.aexists():
return await conversation.afirst()
return await Conversation.objects.acreate(user=user)
@staticmethod
def has_any_conversation_config(user: KhojUser):
return ConversationProcessorConfig.objects.filter(user=user).exists()
@staticmethod
def get_openai_conversation_config(user: KhojUser):
return OpenAIProcessorConversationConfig.objects.filter(user=user).first()
@staticmethod
def get_offline_chat_conversation_config(user: KhojUser):
return OfflineChatProcessorConversationConfig.objects.filter(user=user).first()
@staticmethod
def has_valid_offline_conversation_config(user: KhojUser):
return OfflineChatProcessorConversationConfig.objects.filter(user=user, enable_offline_chat=True).exists()
@staticmethod
def has_valid_openai_conversation_config(user: KhojUser):
return OpenAIProcessorConversationConfig.objects.filter(user=user).exists()
@staticmethod
def get_conversation_config(user: KhojUser):
return ConversationProcessorConfig.objects.filter(user=user).first()
@staticmethod
def save_conversation(user: KhojUser, conversation_log: dict):
conversation = Conversation.objects.filter(user=user)
if conversation.exists():
conversation.update(conversation_log=conversation_log)
else:
Conversation.objects.create(user=user, conversation_log=conversation_log)
@staticmethod
def set_conversation_processor_config(user: KhojUser, new_config: UserConversationProcessorConfig):
conversation_config, _ = ConversationProcessorConfig.objects.get_or_create(user=user)
conversation_config.max_prompt_size = new_config.max_prompt_size
conversation_config.tokenizer = new_config.tokenizer
conversation_config.save()
if new_config.openai:
default_values = {
"api_key": new_config.openai.api_key,
}
if new_config.openai.chat_model:
default_values["chat_model"] = new_config.openai.chat_model
OpenAIProcessorConversationConfig.objects.update_or_create(user=user, defaults=default_values)
if new_config.offline_chat:
default_values = {
"enable_offline_chat": str(new_config.offline_chat.enable_offline_chat),
}
if new_config.offline_chat.chat_model:
default_values["chat_model"] = new_config.offline_chat.chat_model
OfflineChatProcessorConversationConfig.objects.update_or_create(user=user, defaults=default_values)
@staticmethod
def get_enabled_conversation_settings(user: KhojUser):
openai_config = ConversationAdapters.get_openai_conversation_config(user)
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user)
return {
"openai": True if openai_config is not None else False,
"offline_chat": True
if (offline_chat_config is not None and offline_chat_config.enable_offline_chat)
else False,
}
@staticmethod
def clear_conversation_config(user: KhojUser):
ConversationProcessorConfig.objects.filter(user=user).delete()
ConversationAdapters.clear_openai_conversation_config(user)
ConversationAdapters.clear_offline_chat_conversation_config(user)
@staticmethod
def clear_openai_conversation_config(user: KhojUser):
OpenAIProcessorConversationConfig.objects.filter(user=user).delete()
@staticmethod
def clear_offline_chat_conversation_config(user: KhojUser):
OfflineChatProcessorConversationConfig.objects.filter(user=user).delete()
@staticmethod
async def has_offline_chat(user: KhojUser):
return await OfflineChatProcessorConversationConfig.objects.filter(
user=user, enable_offline_chat=True
).aexists()
@staticmethod
async def get_offline_chat(user: KhojUser):
return await OfflineChatProcessorConversationConfig.objects.filter(user=user).afirst()
@staticmethod
async def has_openai_chat(user: KhojUser):
return await OpenAIProcessorConversationConfig.objects.filter(user=user).aexists()
@staticmethod
async def get_openai_chat(user: KhojUser):
return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst()
class EmbeddingsAdapters:
word_filer = WordFilter()
file_filter = FileFilter()
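
A usage sketch for the adapters above, assuming `user` is an authenticated KhojUser; sync callers get lazy get-or-create semantics, and async request handlers use the `a`-prefixed variants:

```python
from database.adapters import ConversationAdapters

def load_history(user):
    # Creates an empty Conversation on first access, then reuses it.
    return ConversationAdapters.get_conversation_by_user(user).conversation_log

async def aload_history(user):
    conversation = await ConversationAdapters.aget_conversation_by_user(user)
    return conversation.conversation_log
```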

View file

@ -0,0 +1,81 @@
# Generated by Django 4.2.5 on 2023-10-18 05:31
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
("database", "0006_embeddingsdates"),
]
operations = [
migrations.RemoveField(
model_name="conversationprocessorconfig",
name="conversation",
),
migrations.RemoveField(
model_name="conversationprocessorconfig",
name="enable_offline_chat",
),
migrations.AddField(
model_name="conversationprocessorconfig",
name="max_prompt_size",
field=models.IntegerField(blank=True, default=None, null=True),
),
migrations.AddField(
model_name="conversationprocessorconfig",
name="tokenizer",
field=models.CharField(blank=True, default=None, max_length=200, null=True),
),
migrations.AddField(
model_name="conversationprocessorconfig",
name="user",
field=models.ForeignKey(
default=1, on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
),
preserve_default=False,
),
migrations.CreateModel(
name="OpenAIProcessorConversationConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("api_key", models.CharField(max_length=200)),
("chat_model", models.CharField(max_length=200)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="OfflineChatProcessorConversationConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("enable_offline_chat", models.BooleanField(default=False)),
("chat_model", models.CharField(default="llama-2-7b-chat.ggmlv3.q4_0.bin", max_length=200)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="Conversation",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("conversation_log", models.JSONField()),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
]
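
Note the `default=1` on the new `user` foreign key: when migrating an existing database, rows in `conversationprocessorconfig` are backfilled against the user with primary key 1, so that user must exist. A sketch of applying the migrations programmatically (the usual route is `python manage.py migrate` from the CLI):

```python
# Assumes DJANGO_SETTINGS_MODULE is set (see the pytest.ini change above).
import django
from django.core.management import call_command

django.setup()
call_command("migrate", "database")
```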

View file

@ -0,0 +1,17 @@
# Generated by Django 4.2.5 on 2023-10-18 16:46
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0007_remove_conversationprocessorconfig_conversation_and_more"),
]
operations = [
migrations.AlterField(
model_name="conversation",
name="conversation_log",
field=models.JSONField(default=dict),
),
]

View file

@ -82,9 +82,27 @@ class LocalPlaintextConfig(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class ConversationProcessorConfig(BaseModel):
conversation = models.JSONField()
class OpenAIProcessorConversationConfig(BaseModel):
api_key = models.CharField(max_length=200)
chat_model = models.CharField(max_length=200)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class OfflineChatProcessorConversationConfig(BaseModel):
enable_offline_chat = models.BooleanField(default=False)
chat_model = models.CharField(max_length=200, default="llama-2-7b-chat.ggmlv3.q4_0.bin")
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class ConversationProcessorConfig(BaseModel):
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class Conversation(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict)
class Embeddings(BaseModel):
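
These config models are effectively one row per user; a sketch of the upsert pattern the adapters use against them (the API key value is a placeholder):

```python
from database.models import KhojUser, OpenAIProcessorConversationConfig

user = KhojUser.objects.first()  # illustrative: any existing user
# Create or update this user's OpenAI settings in a single upsert.
OpenAIProcessorConversationConfig.objects.update_or_create(
    user=user,
    defaults={"api_key": "sk-placeholder", "chat_model": "gpt-3.5-turbo"},
)
```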

View file

@ -23,12 +23,10 @@ from starlette.authentication import (
from khoj.utils import constants, state
from khoj.utils.config import (
SearchType,
ProcessorConfigModel,
ConversationProcessorConfigModel,
)
from khoj.utils.helpers import resolve_absolute_path, merge_dicts
from khoj.utils.helpers import merge_dicts
from khoj.utils.fs_syncer import collect_files
from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig
from khoj.utils.rawconfig import FullConfig
from khoj.routers.indexer import configure_content, load_content, configure_search
from database.models import KhojUser
from database.adapters import get_all_users
@ -98,13 +96,6 @@ def configure_server(
# Update Config
state.config = config
# Initialize Processor from Config
try:
state.processor_config = configure_processor(state.config.processor)
except Exception as e:
logger.error(f"🚨 Failed to configure processor", exc_info=True)
raise e
# Initialize Search Models from Config and initialize content
try:
state.config_lock.acquire()
@ -190,103 +181,6 @@ def configure_search_types(config: FullConfig):
return Enum("SearchType", merge_dicts(core_search_types, {}))
def configure_processor(
processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
):
if not processor_config:
logger.warning("🚨 No Processor configuration available.")
return None
processor = ProcessorConfigModel()
# Initialize Conversation Processor
logger.info("💬 Setting up conversation processor")
processor.conversation = configure_conversation_processor(processor_config, state_processor_config)
return processor
def configure_conversation_processor(
processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
):
if (
not processor_config
or not processor_config.conversation
or not processor_config.conversation.conversation_logfile
):
default_config = constants.default_config
default_conversation_logfile = resolve_absolute_path(
default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
)
conversation_logfile = resolve_absolute_path(default_conversation_logfile)
conversation_config = processor_config.conversation if processor_config else None
conversation_processor = ConversationProcessorConfigModel(
conversation_config=ConversationProcessorConfig(
conversation_logfile=conversation_logfile,
openai=(conversation_config.openai if (conversation_config is not None) else None),
offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(),
max_prompt_size=conversation_config.max_prompt_size if conversation_config else None,
tokenizer=conversation_config.tokenizer if conversation_config else None,
)
)
else:
conversation_processor = ConversationProcessorConfigModel(
conversation_config=processor_config.conversation,
)
conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile)
# Load Conversation Logs from Disk
if state_processor_config and state_processor_config.conversation and state_processor_config.conversation.meta_log:
conversation_processor.meta_log = state_processor_config.conversation.meta_log
conversation_processor.chat_session = state_processor_config.conversation.chat_session
logger.debug(f"Loaded conversation logs from state")
return conversation_processor
if conversation_logfile.is_file():
# Load Metadata Logs from Conversation Logfile
with conversation_logfile.open("r") as f:
conversation_processor.meta_log = json.load(f)
logger.debug(f"Loaded conversation logs from {conversation_logfile}")
else:
# Initialize Conversation Logs
conversation_processor.meta_log = {}
conversation_processor.chat_session = []
return conversation_processor
@schedule.repeat(schedule.every(17).minutes)
def save_chat_session():
# No need to create empty log file
if not (
state.processor_config
and state.processor_config.conversation
and state.processor_config.conversation.meta_log
and state.processor_config.conversation.chat_session
):
return
# Summarize Conversation Logs for this Session
conversation_log = state.processor_config.conversation.meta_log
session = {
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
"session-end": len(conversation_log["chat"]),
}
if "session" in conversation_log:
conversation_log["session"].append(session)
else:
conversation_log["session"] = [session]
# Save Conversation Metadata Logs to Disk
conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile)
conversation_logfile.parent.mkdir(parents=True, exist_ok=True) # create conversation directory if doesn't exist
with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
json.dump(conversation_log, logfile, indent=2)
state.processor_config.conversation.chat_session = []
logger.info("📩 Saved current chat session to conversation logs")
@schedule.repeat(schedule.every(59).minutes)
def upload_telemetry():
if not state.config or not state.config.app or not state.config.app.should_log_telemetry or not state.telemetry:

View file

@ -3,6 +3,11 @@
<div class="page">
<div class="section">
{% if anonymous_mode == False %}
<div>
Logged in as {{ username }}
</div>
{% endif %}
<h2 class="section-title">Plugins</h2>
<div class="section-cards">
<div class="card">
@ -257,8 +262,8 @@
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
<h3 class="card-title">
Offline Chat
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
{% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and not current_model_state.conversation_gpt4all %}
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_model_state.enable_offline_model and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
{% if current_model_state.enable_offline_model and not current_model_state.conversation_gpt4all %}
<img id="misconfigured-icon-conversation-enable-offline-chat" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The model was not downloaded as expected.">
{% endif %}
</h3>
@ -266,12 +271,12 @@
<div class="card-description-row">
<p class="card-description">Setup offline chat</p>
</div>
<div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
<div id="clear-enable-offline-chat" class="card-action-row {% if current_model_state.enable_offline_model %}enabled{% else %}disabled{% endif %}">
<button class="card-button" onclick="toggleEnableLocalLLLM(false)">
Disable
</button>
</div>
<div id="set-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}disabled{% else %}enabled{% endif %}">
<div id="set-enable-offline-chat" class="card-action-row {% if current_model_state.enable_offline_model %}disabled{% else %}enabled{% endif %}">
<button class="card-button happy" onclick="toggleEnableLocalLLLM(true)">
Enable
</button>

View file

@ -8,21 +8,20 @@ from typing import List, Optional, Union, Any
import asyncio
# External Packages
from fastapi import APIRouter, HTTPException, Header, Request, Depends
from fastapi import APIRouter, HTTPException, Header, Request
from starlette.authentication import requires
from asgiref.sync import sync_to_async
# Internal Packages
from khoj.configure import configure_processor, configure_server
from khoj.configure import configure_server
from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils.config import TextSearchModel
from khoj.utils.config import TextSearchModel, GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer, command_descriptions
from khoj.utils.rawconfig import (
FullConfig,
ProcessorConfig,
SearchConfig,
SearchResponse,
TextContentConfig,
@ -32,16 +31,16 @@ from khoj.utils.rawconfig import (
ConversationProcessorConfig,
OfflineChatProcessorConfig,
)
from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.state import SearchType
from khoj.utils import state, constants
from khoj.utils.yaml import save_config_to_file_updated_state
from khoj.utils.helpers import AsyncIteratorWrapper
from fastapi.responses import StreamingResponse, Response
from khoj.routers.helpers import (
get_conversation_command,
perform_chat_checks,
generate_chat_response,
agenerate_chat_response,
update_telemetry_state,
is_ready_to_chat,
)
from khoj.processor.conversation.prompts import help_message
from khoj.processor.conversation.openai.gpt import extract_questions
@ -49,7 +48,7 @@ from khoj.processor.conversation.gpt4all.chat_model import extract_questions_off
from fastapi.requests import Request
from database import adapters
from database.adapters import EmbeddingsAdapters
from database.adapters import EmbeddingsAdapters, ConversationAdapters
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser
@ -114,6 +113,8 @@ async def map_config_to_db(config: FullConfig, user: KhojUser):
user=user,
token=config.content_type.notion.token,
)
if config.processor and config.processor.conversation:
ConversationAdapters.set_conversation_processor_config(user, config.processor.conversation)
# If it's a demo instance, prevent updating any of the configuration.
@ -123,8 +124,6 @@ if not state.demo:
if state.config is None:
state.config = FullConfig()
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
if state.processor_config is None:
state.processor_config = configure_processor(state.config.processor)
@api.get("/config/data", response_model=FullConfig)
@requires(["authenticated"], redirect="login_page")
@ -238,28 +237,24 @@ if not state.demo:
)
content_object = map_config_to_object(content_type)
if content_object is None:
raise ValueError(f"Invalid content type: {content_type}")
await content_object.objects.filter(user=user).adelete()
await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type)
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
return {"status": "ok"}
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
@requires(["authenticated"], redirect="login_page")
async def remove_processor_conversation_config_data(
request: Request,
client: Optional[str] = None,
):
if (
not state.config
or not state.config.processor
or not state.config.processor.conversation
or not state.config.processor.conversation.openai
):
return {"status": "ok"}
user = request.user.object
state.config.processor.conversation.openai = None
state.processor_config = configure_processor(state.config.processor, state.processor_config)
await sync_to_async(ConversationAdapters.clear_openai_conversation_config)(user)
update_telemetry_state(
request=request,
@ -269,11 +264,7 @@ if not state.demo:
metadata={"processor_conversation_type": "openai"},
)
try:
save_config_to_file_updated_state()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "message": str(e)}
return {"status": "ok"}
@api.post("/config/data/content_type/{content_type}", status_code=200)
@requires(["authenticated"], redirect="login_page")
@ -301,24 +292,17 @@ if not state.demo:
return {"status": "ok"}
@api.post("/config/data/processor/conversation/openai", status_code=200)
@requires(["authenticated"], redirect="login_page")
async def set_processor_openai_config_data(
request: Request,
updated_config: Union[OpenAIProcessorConfig, None],
client: Optional[str] = None,
):
_initialize_config()
user = request.user.object
if not state.config.processor or not state.config.processor.conversation:
default_config = constants.default_config
default_conversation_logfile = resolve_absolute_path(
default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
)
conversation_logfile = resolve_absolute_path(default_conversation_logfile)
state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
conversation_config = ConversationProcessorConfig(openai=updated_config)
assert state.config.processor.conversation is not None
state.config.processor.conversation.openai = updated_config
state.processor_config = configure_processor(state.config.processor, state.processor_config)
await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config)
update_telemetry_state(
request=request,
@ -328,11 +312,7 @@ if not state.demo:
metadata={"processor_conversation_type": "conversation"},
)
try:
save_config_to_file_updated_state()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "message": str(e)}
return {"status": "ok"}
@api.post("/config/data/processor/conversation/offline_chat", status_code=200)
async def set_processor_enable_offline_chat_config_data(
@ -341,24 +321,26 @@ if not state.demo:
offline_chat_model: Optional[str] = None,
client: Optional[str] = None,
):
_initialize_config()
user = request.user.object
if not state.config.processor or not state.config.processor.conversation:
default_config = constants.default_config
default_conversation_logfile = resolve_absolute_path(
default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
if enable_offline_chat:
conversation_config = ConversationProcessorConfig(
offline_chat=OfflineChatProcessorConfig(
enable_offline_chat=enable_offline_chat,
chat_model=offline_chat_model,
)
)
conversation_logfile = resolve_absolute_path(default_conversation_logfile)
state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
assert state.config.processor.conversation is not None
if state.config.processor.conversation.offline_chat is None:
state.config.processor.conversation.offline_chat = OfflineChatProcessorConfig()
await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config)
state.config.processor.conversation.offline_chat.enable_offline_chat = enable_offline_chat
if offline_chat_model is not None:
state.config.processor.conversation.offline_chat.chat_model = offline_chat_model
state.processor_config = configure_processor(state.config.processor, state.processor_config)
offline_chat = await ConversationAdapters.get_offline_chat(user)
chat_model = offline_chat.chat_model
if state.gpt4all_processor_config is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
else:
await sync_to_async(ConversationAdapters.clear_offline_chat_conversation_config)(user)
state.gpt4all_processor_config = None
update_telemetry_state(
request=request,
@ -368,11 +350,7 @@ if not state.demo:
metadata={"processor_conversation_type": f"{'enable' if enable_offline_chat else 'disable'}_local_llm"},
)
try:
save_config_to_file_updated_state()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "message": str(e)}
return {"status": "ok"}
# Create Routes
@ -426,9 +404,6 @@ async def search(
if q is None or q == "":
logger.warning(f"No query param (q) passed in API call to initiate search")
return results
if not state.search_models or not any(state.search_models.__dict__.values()):
logger.warning(f"No search models loaded. Configure a search model before initiating search")
return results
# initialize variables
user_query = q.strip()
@ -565,8 +540,6 @@ def update(
components.append("Search models")
if state.content_index:
components.append("Content index")
if state.processor_config:
components.append("Conversation processor")
components_msg = ", ".join(components)
logger.info(f"📪 {components_msg} updated via API")
@ -592,12 +565,11 @@ def chat_history(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
perform_chat_checks()
user = request.user.object
perform_chat_checks(user)
# Load Conversation History
meta_log = {}
if state.processor_config.conversation:
meta_log = state.processor_config.conversation.meta_log
meta_log = ConversationAdapters.get_conversation_by_user(user=user).conversation_log
update_telemetry_state(
request=request,
@ -649,30 +621,35 @@ async def chat(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
) -> Response:
perform_chat_checks()
user = request.user.object
await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True)
q = q.replace(f"/{conversation_command.value}", "").strip()
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, q, (n or 5), conversation_command
request, meta_log, q, (n or 5), conversation_command
)
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
conversation_command = ConversationCommand.General
if conversation_command == ConversationCommand.Help:
model_type = "offline" if state.processor_config.conversation.offline_chat.enable_offline_chat else "openai"
model_type = "offline" if await ConversationAdapters.has_offline_chat(user) else "openai"
formatted_help = help_message.format(model=model_type, version=state.khoj_version)
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
# Get the (streamed) chat response from the LLM of choice.
llm_response = generate_chat_response(
llm_response = await agenerate_chat_response(
defiltered_query,
meta_log=state.processor_config.conversation.meta_log,
compiled_references=compiled_references,
inferred_queries=inferred_queries,
conversation_command=conversation_command,
meta_log,
compiled_references,
inferred_queries,
conversation_command,
user,
)
if llm_response is None:
@ -681,13 +658,14 @@ async def chat(
if stream:
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
iterator = AsyncIteratorWrapper(llm_response)
# Get the full response from the generator if the stream is not requested.
aggregated_gpt_response = ""
while True:
try:
aggregated_gpt_response += next(llm_response)
except StopIteration:
async for item in iterator:
if item is None:
break
aggregated_gpt_response += item
actual_response = aggregated_gpt_response.split("### compiled references:")[0]
@ -708,44 +686,53 @@ async def chat(
async def extract_references_and_questions(
request: Request,
meta_log: dict,
q: str,
n: int,
conversation_type: ConversationCommand = ConversationCommand.Default,
):
user = request.user.object if request.user.is_authenticated else None
# Load Conversation History
meta_log = state.processor_config.conversation.meta_log
# Initialize Variables
compiled_references: List[Any] = []
inferred_queries: List[str] = []
if not EmbeddingsAdapters.user_has_embeddings(user=user):
if conversation_type == ConversationCommand.General:
return compiled_references, inferred_queries, q
if not await EmbeddingsAdapters.user_has_embeddings(user=user):
logger.warning(
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
)
return compiled_references, inferred_queries, q
if conversation_type == ConversationCommand.General:
return compiled_references, inferred_queries, q
# Extract filter terms from user message
defiltered_query = q
for filter in [DateFilter(), WordFilter(), FileFilter()]:
defiltered_query = filter.defilter(defiltered_query)
filters_in_query = q.replace(defiltered_query, "").strip()
using_offline_chat = False
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
if state.processor_config.conversation.offline_chat.enable_offline_chat:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
if await ConversationAdapters.has_offline_chat(user):
using_offline_chat = True
offline_chat = await ConversationAdapters.get_offline_chat(user)
chat_model = offline_chat.chat_model
if state.gpt4all_processor_config is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
loaded_model = state.gpt4all_processor_config.loaded_model
inferred_queries = extract_questions_offline(
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
)
elif state.processor_config.conversation.openai_model:
api_key = state.processor_config.conversation.openai_model.api_key
chat_model = state.processor_config.conversation.openai_model.chat_model
elif await ConversationAdapters.has_openai_chat(user):
openai_chat = await ConversationAdapters.get_openai_chat(user)
api_key = openai_chat.api_key
chat_model = openai_chat.chat_model
inferred_queries = extract_questions(
defiltered_query, model=chat_model, api_key=api_key, conversation_log=meta_log
)
@ -754,7 +741,7 @@ async def extract_references_and_questions(
with timer("Searching knowledge base took", logger):
result_list = []
for query in inferred_queries:
n_items = min(n, 3) if state.processor_config.conversation.offline_chat.enable_offline_chat else n
n_items = min(n, 3) if using_offline_chat else n
result_list.extend(
await search(
f"{query} {filters_in_query}",
@ -765,6 +752,8 @@ async def extract_references_and_questions(
dedupe=False,
)
)
# Dedupe the results again, as duplicates may be returned across queries.
result_list = text_search.deduplicated_search_responses(result_list)
compiled_references = [item.additional["compiled"] for item in result_list]
return compiled_references, inferred_queries, defiltered_query

View file

@ -1,34 +1,50 @@
import logging
import asyncio
from datetime import datetime
from functools import partial
from typing import Iterator, List, Optional, Union
from concurrent.futures import ThreadPoolExecutor
from fastapi import HTTPException, Request
from khoj.utils import state
from khoj.utils.helpers import ConversationCommand, timer, log_telemetry
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry
from khoj.processor.conversation.openai.gpt import converse
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
from database.models import KhojUser
from database.adapters import ConversationAdapters
logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1)
def perform_chat_checks():
if (
state.processor_config
and state.processor_config.conversation
and (
state.processor_config.conversation.openai_model
or state.processor_config.conversation.gpt4all_model.loaded_model
)
):
def perform_chat_checks(user: KhojUser):
if ConversationAdapters.has_valid_offline_conversation_config(
user
) or ConversationAdapters.has_valid_openai_conversation_config(user):
return
raise HTTPException(
status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings and restart it."
)
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
async def is_ready_to_chat(user: KhojUser):
has_offline_config = await ConversationAdapters.has_offline_chat(user=user)
has_openai_config = await ConversationAdapters.has_openai_chat(user=user)
if has_offline_config:
offline_chat = await ConversationAdapters.get_offline_chat(user)
chat_model = offline_chat.chat_model
if state.gpt4all_processor_config is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
return True
ready = has_openai_config or has_offline_config
if not ready:
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
def update_telemetry_state(
@ -74,12 +90,22 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.Default
async def construct_conversation_logs(user: KhojUser):
return (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
async def agenerate_chat_response(*args):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(executor, generate_chat_response, *args)
def generate_chat_response(
q: str,
meta_log: dict,
compiled_references: List[str] = [],
inferred_queries: List[str] = [],
conversation_command: ConversationCommand = ConversationCommand.Default,
user: KhojUser = None,
) -> Union[ThreadedGenerator, Iterator[str]]:
def _save_to_conversation_log(
q: str,
@ -89,17 +115,14 @@ def generate_chat_response(
inferred_queries: List[str],
meta_log,
):
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, chat_response])
state.processor_config.conversation.meta_log["chat"] = message_to_log(
updated_conversation = message_to_log(
user_message=q,
chat_response=chat_response,
user_message_metadata={"created": user_message_time},
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
conversation_log=meta_log.get("chat", []),
)
# Load Conversation History
meta_log = state.processor_config.conversation.meta_log
ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
# Initialize Variables
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@ -116,8 +139,14 @@ def generate_chat_response(
meta_log=meta_log,
)
if state.processor_config.conversation.offline_chat.enable_offline_chat:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user=user)
conversation_config = ConversationAdapters.get_conversation_config(user)
openai_chat_config = ConversationAdapters.get_openai_conversation_config(user)
if offline_chat_config:
if state.gpt4all_processor_config.loaded_model is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(offline_chat_config.chat_model)
loaded_model = state.gpt4all_processor_config.loaded_model
chat_response = converse_offline(
references=compiled_references,
user_query=q,
@ -125,14 +154,14 @@ def generate_chat_response(
conversation_log=meta_log,
completion_func=partial_completion,
conversation_command=conversation_command,
model=state.processor_config.conversation.offline_chat.chat_model,
max_prompt_size=state.processor_config.conversation.max_prompt_size,
tokenizer_name=state.processor_config.conversation.tokenizer,
model=offline_chat_config.chat_model,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
)
elif state.processor_config.conversation.openai_model:
api_key = state.processor_config.conversation.openai_model.api_key
chat_model = state.processor_config.conversation.openai_model.chat_model
elif openai_chat_config:
api_key = openai_chat_config.api_key
chat_model = openai_chat_config.chat_model
chat_response = converse(
compiled_references,
q,
@ -141,8 +170,8 @@ def generate_chat_response(
api_key=api_key,
completion_func=partial_completion,
conversation_command=conversation_command,
max_prompt_size=state.processor_config.conversation.max_prompt_size,
tokenizer_name=state.processor_config.conversation.tokenizer,
max_prompt_size=conversation_config.max_prompt_size if conversation_config else None,
tokenizer_name=conversation_config.tokenizer if conversation_config else None,
)
except Exception as e:
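
`agenerate_chat_response` above bridges the synchronous `generate_chat_response` (which touches the Django ORM and returns a ThreadedGenerator) into async code via a single-worker thread pool. A self-contained sketch of that pattern:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)

def blocking_work(q: str) -> str:
    return f"response to {q}"  # stand-in for the real synchronous call

async def main():
    loop = asyncio.get_event_loop()
    # Run the blocking function on the pool so the event loop stays free.
    result = await loop.run_in_executor(executor, blocking_work, "hello")
    print(result)

asyncio.run(main())
```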

View file

@ -92,7 +92,7 @@ async def update(
if dict_to_update is not None:
dict_to_update[file.filename] = (
file.file.read().decode("utf-8") if encoding == "utf-8" else file.file.read()
file.file.read().decode("utf-8") if encoding == "utf-8" else file.file.read() # type: ignore
)
else:
logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")
@ -181,24 +181,25 @@ def configure_content(
files: Optional[dict[str, dict[str, str]]],
search_models: SearchModels,
regenerate: bool = False,
t: Optional[Union[state.SearchType, str]] = None,
t: Optional[state.SearchType] = None,
full_corpus: bool = True,
user: KhojUser = None,
) -> Optional[ContentIndex]:
content_index = ContentIndex()
if t in [type.value for type in state.SearchType]:
t = state.SearchType(t).value
if t is not None and not t.value in [type.value for type in state.SearchType]:
logger.warning(f"🚨 Invalid search type: {t}")
return None
assert type(t) == str or t == None, f"Invalid search type: {t}"
search_type = t.value if t else None
if files is None:
logger.warning(f"🚨 No files to process for {t} search.")
logger.warning(f"🚨 No files to process for {search_type} search.")
return None
try:
# Initialize Org Notes Search
if (t == None or t == state.SearchType.Org.value) and files["org"]:
if (search_type == None or search_type == state.SearchType.Org.value) and files["org"]:
logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings
text_search.setup(
@ -213,7 +214,7 @@ def configure_content(
try:
# Initialize Markdown Search
if (t == None or t == state.SearchType.Markdown.value) and files["markdown"]:
if (search_type == None or search_type == state.SearchType.Markdown.value) and files["markdown"]:
logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings
text_search.setup(
@ -229,7 +230,7 @@ def configure_content(
try:
# Initialize PDF Search
if (t == None or t == state.SearchType.Pdf.value) and files["pdf"]:
if (search_type == None or search_type == state.SearchType.Pdf.value) and files["pdf"]:
logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings
text_search.setup(
@ -245,7 +246,7 @@ def configure_content(
try:
# Initialize Plaintext Search
if (t == None or t == state.SearchType.Plaintext.value) and files["plaintext"]:
if (search_type == None or search_type == state.SearchType.Plaintext.value) and files["plaintext"]:
logger.info("📄 Setting up search for plaintext")
# Extract Entries, Generate Plaintext Embeddings
text_search.setup(
@ -262,7 +263,7 @@ def configure_content(
try:
# Initialize Image Search
if (
(t == None or t == state.SearchType.Image.value)
(search_type == None or search_type == state.SearchType.Image.value)
and content_config
and content_config.image
and search_models.image_search
@ -278,7 +279,7 @@ def configure_content(
try:
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
if (t == None or t == state.SearchType.Github.value) and github_config is not None:
if (search_type == None or search_type == state.SearchType.Github.value) and github_config is not None:
logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings
text_search.setup(
@ -296,7 +297,7 @@ def configure_content(
try:
# Initialize Notion Search
notion_config = NotionConfig.objects.filter(user=user).first()
if (t == None or t in state.SearchType.Notion.value) and notion_config:
if (search_type == None or search_type in state.SearchType.Notion.value) and notion_config:
logger.info("🔌 Setting up search for notion")
text_search.setup(
NotionToJsonl,

View file

@ -19,7 +19,7 @@ from khoj.utils.rawconfig import (
# Internal Packages
from khoj.utils import constants, state
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters
from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
@ -83,7 +83,7 @@ if not state.demo:
@web_client.get("/config", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def config_page(request: Request):
user = request.user.object if request.user.is_authenticated else None
user = request.user.object
enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
default_full_config = FullConfig(
content_type=None,
@ -100,9 +100,6 @@ if not state.demo:
"github": ("github" in enabled_content),
"notion": ("notion" in enabled_content),
"plaintext": ("plaintext" in enabled_content),
"enable_offline_model": False,
"conversation_openai": False,
"conversation_gpt4all": False,
}
if state.content_index:
@ -112,13 +109,17 @@ if not state.demo:
}
)
if state.processor_config and state.processor_config.conversation:
successfully_configured.update(
{
"conversation_openai": state.processor_config.conversation.openai_model is not None,
"conversation_gpt4all": state.processor_config.conversation.gpt4all_model.loaded_model is not None,
}
)
enabled_chat_config = ConversationAdapters.get_enabled_conversation_settings(user)
successfully_configured.update(
{
"conversation_openai": enabled_chat_config["openai"],
"enable_offline_model": enabled_chat_config["offline_chat"],
"conversation_gpt4all": state.gpt4all_processor_config.loaded_model is not None
if state.gpt4all_processor_config
else False,
}
)
return templates.TemplateResponse(
"config.html",
@ -127,6 +128,7 @@ if not state.demo:
"current_config": current_config,
"current_model_state": successfully_configured,
"anonymous_mode": state.anonymous_mode,
"username": user.username if user else None,
},
)
@ -204,22 +206,22 @@ if not state.demo:
)
@web_client.get("/config/processor/conversation/openai", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def conversation_processor_config_page(request: Request):
default_copy = constants.default_config.copy()
default_processor_config = default_copy["processor"]["conversation"]["openai"] # type: ignore
default_openai_config = OpenAIProcessorConfig(
api_key="",
chat_model=default_processor_config["chat-model"],
)
user = request.user.object
openai_config = ConversationAdapters.get_openai_conversation_config(user)
if openai_config:
current_processor_openai_config = OpenAIProcessorConfig(
api_key=openai_config.api_key,
chat_model=openai_config.chat_model,
)
else:
current_processor_openai_config = OpenAIProcessorConfig(
api_key="",
chat_model="gpt-3.5-turbo",
)
current_processor_openai_config = (
state.config.processor.conversation.openai
if state.config
and state.config.processor
and state.config.processor.conversation
and state.config.processor.conversation.openai
else default_openai_config
)
current_processor_openai_config = json.loads(current_processor_openai_config.json())
return templates.TemplateResponse(

View file

@ -236,6 +236,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
"image_score": f"{hit['image_score']:.9f}",
"metadata_score": f"{hit['metadata_score']:.9f}",
},
"corpus_id": hit["corpus_id"],
}
)
]

View file

@ -14,10 +14,9 @@ from asgiref.sync import sync_to_async
# Internal Packages
from khoj.utils import state
from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer
from khoj.utils.config import TextSearchModel
from khoj.utils.models import BaseEncoder
from khoj.utils.state import SearchType
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, Entry
from khoj.utils.rawconfig import SearchResponse, Entry
from khoj.utils.jsonl import load_jsonl
from khoj.processor.text_to_jsonl import TextEmbeddings
from database.adapters import EmbeddingsAdapters
@ -36,36 +35,6 @@ search_type_to_embeddings_type = {
}
def initialize_model(search_config: TextSearchConfig):
"Initialize model for semantic search on text"
torch.set_num_threads(4)
# If model directory is configured
if search_config.model_directory:
# Convert model directory to absolute path
search_config.model_directory = resolve_absolute_path(search_config.model_directory)
# Create model directory if it doesn't exist
search_config.model_directory.parent.mkdir(parents=True, exist_ok=True)
# The bi-encoder encodes all entries to use for semantic search
bi_encoder = load_model(
model_dir=search_config.model_directory,
model_name=search_config.encoder,
model_type=search_config.encoder_type or SentenceTransformer,
device=f"{state.device}",
)
# The cross-encoder re-ranks the results to improve quality
cross_encoder = load_model(
model_dir=search_config.model_directory,
model_name=search_config.cross_encoder,
model_type=CrossEncoder,
device=f"{state.device}",
)
return TextSearchModel(bi_encoder, cross_encoder)
def extract_entries(jsonl_file) -> List[Entry]:
"Load entries from compressed jsonl"
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
@ -176,6 +145,7 @@ def collate_results(hits, dedupe=True):
{
"entry": hit.raw,
"score": hit.distance,
"corpus_id": str(hit.corpus_id),
"additional": {
"file": hit.file_path,
"compiled": hit.compiled,
@ -185,6 +155,28 @@ def collate_results(hits, dedupe=True):
)
def deduplicated_search_responses(hits: List[SearchResponse]):
hit_ids = set()
for hit in hits:
if hit.corpus_id in hit_ids:
continue
else:
hit_ids.add(hit.corpus_id)
yield SearchResponse.parse_obj(
{
"entry": hit.entry,
"score": hit.score,
"corpus_id": hit.corpus_id,
"additional": {
"file": hit.additional["file"],
"compiled": hit.additional["compiled"],
"heading": hit.additional["heading"],
},
}
)
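
`deduplicated_search_responses` is a generator keyed on the new `corpus_id`, so callers materialize it when the results are reused; an illustrative call over merged per-query results:

```python
from khoj.search_type.text_search import deduplicated_search_responses

# results_a and results_b are assumed lists of SearchResponse hits from
# two different inferred queries against the same index.
unique_hits = list(deduplicated_search_responses(results_a + results_b))
```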
def rerank_and_sort_results(hits, query):
# Score all retrieved entries using the cross-encoder
hits = cross_encoder_score(query, hits)

View file

@ -5,8 +5,7 @@ from enum import Enum
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
from typing import TYPE_CHECKING, List, Optional, Union, Any
from khoj.processor.conversation.gpt4all.utils import download_model
# External Packages
@ -19,9 +18,7 @@ logger = logging.getLogger(__name__)
# Internal Packages
if TYPE_CHECKING:
from sentence_transformers import CrossEncoder
from khoj.search_filter.base_filter import BaseFilter
from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import ConversationProcessorConfig, Entry, OpenAIProcessorConfig
class SearchType(str, Enum):
@ -79,31 +76,15 @@ class GPT4AllProcessorConfig:
loaded_model: Union[Any, None] = None
class ConversationProcessorConfigModel:
class GPT4AllProcessorModel:
def __init__(
self,
conversation_config: ConversationProcessorConfig,
chat_model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
):
self.openai_model = conversation_config.openai
self.gpt4all_model = GPT4AllProcessorConfig()
self.offline_chat = conversation_config.offline_chat or OfflineChatProcessorConfig()
self.max_prompt_size = conversation_config.max_prompt_size
self.tokenizer = conversation_config.tokenizer
self.conversation_logfile = Path(conversation_config.conversation_logfile)
self.chat_session: List[str] = []
self.meta_log: dict = {}
if self.offline_chat.enable_offline_chat:
try:
self.gpt4all_model.loaded_model = download_model(self.offline_chat.chat_model)
except Exception as e:
self.offline_chat.enable_offline_chat = False
self.gpt4all_model.loaded_model = None
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
else:
self.gpt4all_model.loaded_model = None
@dataclass
class ProcessorConfigModel:
conversation: Union[ConversationProcessorConfigModel, None] = None
self.chat_model = chat_model
self.loaded_model = None
try:
self.loaded_model = download_model(self.chat_model)
except ValueError as e:
self.loaded_model = None
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)

View file

@ -8,136 +8,14 @@ telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
content_directory = "~/.khoj/content/"
empty_config = {
"content-type": {
"org": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz",
"embeddings-file": "~/.khoj/content/org/org_embeddings.pt",
"index-heading-entries": False,
},
"markdown": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz",
"embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt",
},
"pdf": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/pdf/pdf.jsonl.gz",
"embeddings-file": "~/.khoj/content/pdf/pdf_embeddings.pt",
},
"plaintext": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/plaintext/plaintext.jsonl.gz",
"embeddings-file": "~/.khoj/content/plaintext/plaintext_embeddings.pt",
},
},
"search-type": {
"symmetric": {
"encoder": "sentence-transformers/all-MiniLM-L6-v2",
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"model_directory": "~/.khoj/search/symmetric/",
},
"asymmetric": {
"encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"model_directory": "~/.khoj/search/asymmetric/",
},
"image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
},
"processor": {
"conversation": {
"openai": {
"api-key": None,
"chat-model": "gpt-3.5-turbo",
},
"offline-chat": {
"enable-offline-chat": False,
"chat-model": "llama-2-7b-chat.ggmlv3.q4_0.bin",
},
"tokenizer": None,
"max-prompt-size": None,
"conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
}
},
}
# default app config to use
default_config = {
"content-type": {
"org": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz",
"embeddings-file": "~/.khoj/content/org/org_embeddings.pt",
"index-heading-entries": False,
},
"markdown": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz",
"embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt",
},
"pdf": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/pdf/pdf.jsonl.gz",
"embeddings-file": "~/.khoj/content/pdf/pdf_embeddings.pt",
},
"image": {
"input-directories": None,
"input-filter": None,
"embeddings-file": "~/.khoj/content/image/image_embeddings.pt",
"batch-size": 50,
"use-xmp-metadata": False,
},
"github": {
"pat-token": None,
"repos": [],
"compressed-jsonl": "~/.khoj/content/github/github.jsonl.gz",
"embeddings-file": "~/.khoj/content/github/github_embeddings.pt",
},
"notion": {
"token": None,
"compressed-jsonl": "~/.khoj/content/notion/notion.jsonl.gz",
"embeddings-file": "~/.khoj/content/notion/notion_embeddings.pt",
},
"plaintext": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/plaintext/plaintext.jsonl.gz",
"embeddings-file": "~/.khoj/content/plaintext/plaintext_embeddings.pt",
},
},
"search-type": {
"symmetric": {
"encoder": "sentence-transformers/all-MiniLM-L6-v2",
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"model_directory": "~/.khoj/search/symmetric/",
},
"asymmetric": {
"encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"model_directory": "~/.khoj/search/asymmetric/",
},
"image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
},
"processor": {
"conversation": {
"openai": {
"api-key": None,
"chat-model": "gpt-3.5-turbo",
},
"offline-chat": {
"enable-offline-chat": False,
"chat-model": "llama-2-7b-chat.ggmlv3.q4_0.bin",
},
"tokenizer": None,
"max-prompt-size": None,
"conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
}
},
}

View file

@ -15,6 +15,7 @@ from time import perf_counter
import torch
from typing import Optional, Union, TYPE_CHECKING
import uuid
from asgiref.sync import sync_to_async
# Internal Packages
from khoj.utils import constants
@ -29,6 +30,28 @@ if TYPE_CHECKING:
from khoj.utils.rawconfig import AppConfig
class AsyncIteratorWrapper:
def __init__(self, obj):
self._it = iter(obj)
def __aiter__(self):
return self
async def __anext__(self):
try:
value = await self.next_async()
except StopAsyncIteration:
return
return value
@sync_to_async
def next_async(self):
try:
return next(self._it)
except StopIteration:
raise StopAsyncIteration
def is_none_or_empty(item):
return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""
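
Note that `__anext__` swallows `StopAsyncIteration` and returns None, so the wrapped stream never terminates an `async for` on its own; callers break on None, as the chat endpoint above does. A usage sketch:

```python
import asyncio
from khoj.utils.helpers import AsyncIteratorWrapper

def sync_tokens():
    yield from ["Hello", ", ", "world"]

async def consume():
    async for token in AsyncIteratorWrapper(sync_tokens()):
        if token is None:  # the wrapper signals exhaustion with None
            break
        print(token, end="")

asyncio.run(consume())
```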

View file

@ -67,13 +67,6 @@ class ContentConfig(ConfigBase):
notion: Optional[NotionContentConfig]
class TextSearchConfig(ConfigBase):
encoder: str
cross_encoder: str
encoder_type: Optional[str]
model_directory: Optional[Path]
class ImageSearchConfig(ConfigBase):
encoder: str
encoder_type: Optional[str]
@ -81,8 +74,6 @@ class ImageSearchConfig(ConfigBase):
class SearchConfig(ConfigBase):
asymmetric: Optional[TextSearchConfig]
symmetric: Optional[TextSearchConfig]
image: Optional[ImageSearchConfig]
@ -97,11 +88,10 @@ class OfflineChatProcessorConfig(ConfigBase):
class ConversationProcessorConfig(ConfigBase):
conversation_logfile: Path
openai: Optional[OpenAIProcessorConfig]
offline_chat: Optional[OfflineChatProcessorConfig]
max_prompt_size: Optional[int]
tokenizer: Optional[str]
openai: Optional[OpenAIProcessorConfig] = None
offline_chat: Optional[OfflineChatProcessorConfig] = None
max_prompt_size: Optional[int] = None
tokenizer: Optional[str] = None
class ProcessorConfig(ConfigBase):
@ -125,6 +115,7 @@ class SearchResponse(ConfigBase):
score: float
cross_score: Optional[float]
additional: Optional[dict]
corpus_id: str
class Entry:
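With openai, offline_chat, max_prompt_size, and tokenizer now defaulting to None, a ConversationProcessorConfig can be constructed from just the logfile path and configured incrementally. A small sketch (import path assumed):

from pathlib import Path
from khoj.utils.rawconfig import ConversationProcessorConfig

config = ConversationProcessorConfig(conversation_logfile=Path("~/.khoj/processor/conversation/conversation_logs.json"))
assert config.openai is None and config.offline_chat is None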


@ -10,7 +10,7 @@ from pathlib import Path
# Internal Packages
from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
from khoj.utils.helpers import LRU
from khoj.utils.rawconfig import FullConfig
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
@ -21,7 +21,7 @@ search_models = SearchModels()
embeddings_model = EmbeddingsModel()
cross_encoder_model = CrossEncoderModel()
content_index = ContentIndex()
processor_config = ProcessorConfigModel()
gpt4all_processor_config: GPT4AllProcessorModel = None
config_file: Path = None
verbose: int = 0
host: str = None
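Swapping the eagerly built ProcessorConfigModel for a nullable gpt4all_processor_config lets the offline chat model load lazily. A hypothetical guard at a call site (the no-argument constructor is an assumption):

from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel

# Only pay the model load cost the first time offline chat is actually used
if state.gpt4all_processor_config is None:
    state.gpt4all_processor_config = GPT4AllProcessorModel()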


@ -5,7 +5,6 @@ from pathlib import Path
import pytest
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI
import factory
import os
from fastapi import FastAPI
@ -13,7 +12,7 @@ app = FastAPI()
# Internal Packages
from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware
from khoj.configure import configure_routes, configure_search_types, configure_middleware
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
from khoj.search_type import image_search, text_search
from khoj.utils.config import SearchModels
@ -21,13 +20,8 @@ from khoj.utils.constants import web_directory
from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import (
ContentConfig,
ConversationProcessorConfig,
OfflineChatProcessorConfig,
OpenAIProcessorConfig,
ProcessorConfig,
ImageContentConfig,
SearchConfig,
TextSearchConfig,
ImageSearchConfig,
)
from khoj.utils import state, fs_syncer
@ -42,42 +36,25 @@ from database.models import (
GithubRepoConfig,
)
from tests.helpers import (
UserFactory,
ConversationProcessorConfigFactory,
OpenAIProcessorConversationConfigFactory,
OfflineChatProcessorConversationConfigFactory,
)
@pytest.fixture(autouse=True)
def enable_db_access_for_all_tests(db):
pass
class UserFactory(factory.django.DjangoModelFactory):
class Meta:
model = KhojUser
username = factory.Faker("name")
email = factory.Faker("email")
password = factory.Faker("password")
uuid = factory.Faker("uuid4")
@pytest.fixture(scope="session")
def search_config() -> SearchConfig:
model_dir = resolve_absolute_path("~/.khoj/search")
model_dir.mkdir(parents=True, exist_ok=True)
search_config = SearchConfig()
search_config.symmetric = TextSearchConfig(
encoder="sentence-transformers/all-MiniLM-L6-v2",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "symmetric/",
encoder_type=None,
)
search_config.asymmetric = TextSearchConfig(
encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
model_directory=model_dir / "asymmetric/",
encoder_type=None,
)
search_config.image = ImageSearchConfig(
encoder="sentence-transformers/clip-ViT-B-32",
model_directory=model_dir / "image/",
@ -177,55 +154,48 @@ def md_content_config():
return markdown_config
@pytest.fixture(scope="session")
def processor_config(tmp_path_factory):
openai_api_key = os.getenv("OPENAI_API_KEY")
processor_dir = tmp_path_factory.mktemp("processor")
# The conversation processor is the only configured processor
# It needs an OpenAI API key to work.
if not openai_api_key:
return
# Setup conversation processor, if OpenAI API key is set
processor_config = ProcessorConfig()
processor_config.conversation = ConversationProcessorConfig(
openai=OpenAIProcessorConfig(api_key=openai_api_key),
conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
)
return processor_config
@pytest.fixture(scope="session")
def processor_config_offline_chat(tmp_path_factory):
processor_dir = tmp_path_factory.mktemp("processor")
# Setup conversation processor
processor_config = ProcessorConfig()
offline_chat = OfflineChatProcessorConfig(enable_offline_chat=True)
processor_config.conversation = ConversationProcessorConfig(
offline_chat=offline_chat,
conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
)
return processor_config
@pytest.fixture(scope="session")
def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
@pytest.fixture(scope="function")
def chat_client(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
LocalMarkdownConfig.objects.create(
input_files=None,
input_filter=["tests/data/markdown/*.markdown"],
user=default_user2,
)
# Index Markdown Content for Search
all_files = fs_syncer.collect_files()
all_files = fs_syncer.collect_files(user=default_user2)
state.content_index = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
)
# Initialize Processor from Config
state.processor_config = configure_processor(processor_config)
if os.getenv("OPENAI_API_KEY"):
OpenAIProcessorConversationConfigFactory(user=default_user2)
state.anonymous_mode = True
app = FastAPI()
configure_routes(app)
configure_middleware(app)
app.mount("/static", StaticFiles(directory=web_directory), name="static")
return TestClient(app)
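Since the fixture now yields a fully wired TestClient scoped to default_user2, downstream tests can hit the chat API directly. A hypothetical smoke test in the style of the ones below:

@pytest.mark.django_db(transaction=True)
def test_chat_smoke(chat_client):
    response = chat_client.get('/api/chat?q="Hello, who are you?"&stream=true')
    assert response.status_code == 200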
@pytest.fixture(scope="function")
def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
# Initialize app state
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
# Initialize Processor from Config
if os.getenv("OPENAI_API_KEY"):
OpenAIProcessorConversationConfigFactory(user=default_user2)
state.anonymous_mode = True
app = FastAPI()
@ -249,7 +219,6 @@ def fastapi_app():
def client(
content_config: ContentConfig,
search_config: SearchConfig,
processor_config: ProcessorConfig,
default_user: KhojUser,
):
state.config.content_type = content_config
@ -274,7 +243,7 @@ def client(
user=default_user,
)
state.processor_config = configure_processor(processor_config)
ConversationProcessorConfigFactory(user=default_user)
state.anonymous_mode = True
configure_routes(app)
@ -286,25 +255,32 @@ def client(
@pytest.fixture(scope="function")
def client_offline_chat(
search_config: SearchConfig,
processor_config_offline_chat: ProcessorConfig,
content_config: ContentConfig,
md_content_config,
default_user2: KhojUser,
):
# Initialize app state
state.config.content_type = md_content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
LocalMarkdownConfig.objects.create(
input_files=None,
input_filter=["tests/data/markdown/*.markdown"],
user=default_user2,
)
# Index Markdown Content for Search
state.search_models.image_search = image_search.initialize_model(search_config.image)
all_files = fs_syncer.collect_files(state.config.content_type)
state.content_index = configure_content(
state.content_index, state.config.content_type, all_files, state.search_models
all_files = fs_syncer.collect_files(user=default_user2)
configure_content(
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
)
# Initialize Processor from Config
state.processor_config = configure_processor(processor_config_offline_chat)
ConversationProcessorConfigFactory(user=default_user2)
OfflineChatProcessorConversationConfigFactory(user=default_user2)
state.anonymous_mode = True
configure_routes(app)

tests/helpers.py (new file, 51 lines)

@ -0,0 +1,51 @@
import factory
import os
from database.models import (
KhojUser,
ConversationProcessorConfig,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig,
Conversation,
)
class UserFactory(factory.django.DjangoModelFactory):
class Meta:
model = KhojUser
username = factory.Faker("name")
email = factory.Faker("email")
password = factory.Faker("password")
uuid = factory.Faker("uuid4")
class ConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
class Meta:
model = ConversationProcessorConfig
max_prompt_size = 2000
tokenizer = None
class OfflineChatProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
class Meta:
model = OfflineChatProcessorConversationConfig
enable_offline_chat = True
chat_model = "llama-2-7b-chat.ggmlv3.q4_0.bin"
class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
class Meta:
model = OpenAIProcessorConversationConfig
api_key = os.getenv("OPENAI_API_KEY")
chat_model = "gpt-3.5-turbo"
class ConversationFactory(factory.django.DjangoModelFactory):
class Meta:
model = Conversation
user = factory.SubFactory(UserFactory)
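These factories replace the ad hoc fixture setup, so a test can materialize a user-scoped conversation in one line. For example (test name hypothetical):

@pytest.mark.django_db(transaction=True)
def test_conversation_is_scoped_to_user(default_user2):
    conversation = ConversationFactory(user=default_user2)
    assert conversation.user == default_user2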


@ -119,7 +119,12 @@ def test_get_configured_types_via_api(client, sample_org_data):
def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser):
# Arrange
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
# Act
response = client.get(f"/api/config/types")
# Assert
assert response.status_code == 200
assert response.json() == ["all", "org", "image"]


@ -9,8 +9,7 @@ from faker import Faker
# Internal Packages
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import message_to_log
from khoj.utils import state
from tests.helpers import ConversationFactory
SKIP_TESTS = True
pytestmark = pytest.mark.skipif(
@ -23,7 +22,7 @@ fake = Faker()
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):
def populate_chat_history(message_list, user):
# Generate conversation logs
conversation_log = {"chat": []}
for user_message, llm_message, context in message_list:
@ -33,14 +32,15 @@ def populate_chat_history(message_list):
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
)
# Update Conversation Metadata Logs in Application State
state.processor_config.conversation.meta_log = conversation_log
# Update Conversation Metadata Logs in Database
ConversationFactory(user=user, conversation_log=conversation_log)
# Tests
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_chat):
# Act
response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
@ -56,13 +56,14 @@ def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_c
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_from_chat_history(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
@ -78,7 +79,8 @@ def test_answer_from_chat_history(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_currently_retrieved_content(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_from_currently_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
@ -88,7 +90,7 @@ def test_answer_from_currently_retrieved_content(client_offline_chat):
["Testatron was born on 1st April 1984 in Testville."],
),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was Xi Li born?"')
@ -101,7 +103,8 @@ def test_answer_from_currently_retrieved_content(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
@ -111,7 +114,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin
["Testatron was born on 1st April 1984 in Testville."],
),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
@ -130,13 +133,14 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin
reason="Chat director not capable of answering this question yet because it requires extract_questions",
)
@pytest.mark.chatquality
def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
@ -154,14 +158,15 @@ def test_answer_from_chat_history_and_currently_retrieved_content(client_offline
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"&stream=true')
@ -177,11 +182,12 @@ def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_using_general_command(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_using_general_command(client_offline_chat, default_user2):
# Arrange
query = urllib.parse.quote("/general Where was Xi Li born?")
message_list = []
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
@ -194,11 +200,12 @@ def test_answer_using_general_command(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_from_retrieved_content_using_notes_command(client_offline_chat, default_user2):
# Arrange
query = urllib.parse.quote("/notes Where was Xi Li born?")
message_list = []
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
@ -211,12 +218,13 @@ def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_using_file_filter(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_using_file_filter(client_offline_chat, default_user2):
# Arrange
no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"')
answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
message_list = []
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
no_answer_response = client_offline_chat.get(f"/api/chat?q={no_answer_query}&stream=true").content.decode("utf-8")
@ -229,11 +237,12 @@ def test_answer_using_file_filter(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_not_known_using_notes_command(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_not_known_using_notes_command(client_offline_chat, default_user2):
# Arrange
query = urllib.parse.quote("/notes Where was Testatron born?")
message_list = []
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
@ -247,6 +256,7 @@ def test_answer_not_known_using_notes_command(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
@freeze_time("2023-04-01")
def test_answer_requires_current_date_awareness(client_offline_chat):
"Chat actor should be able to answer questions relative to current date using provided notes"
@ -265,6 +275,7 @@ def test_answer_requires_current_date_awareness(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
@freeze_time("2023-04-01")
def test_answer_requires_date_aware_aggregation_across_provided_notes(client_offline_chat):
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
@ -280,14 +291,15 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(client_off
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(
@ -307,7 +319,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(client
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.")
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat, default_user2):
# Act
response = client_offline_chat.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
response_message = response.content.decode("utf-8")
@ -328,14 +341,15 @@ def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
@ -350,11 +364,12 @@ def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat):
@pytest.mark.chatquality
def test_answer_chat_history_very_long(client_offline_chat):
@pytest.mark.django_db(transaction=True)
def test_answer_chat_history_very_long(client_offline_chat, default_user2):
# Arrange
message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
@ -368,6 +383,7 @@ def test_answer_chat_history_very_long(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_answer_requires_multiple_independent_searches(client_offline_chat):
"Chat director should be able to answer by doing multiple independent searches for required information"
# Act


@ -9,8 +9,8 @@ from khoj.processor.conversation import prompts
# Internal Packages
from khoj.processor.conversation.utils import message_to_log
from khoj.utils import state
from tests.helpers import ConversationFactory
from database.models import KhojUser
# Initialize variables for tests
api_key = os.getenv("OPENAI_API_KEY")
@ -23,7 +23,7 @@ if api_key is None:
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):
def populate_chat_history(message_list, user=None):
# Generate conversation logs
conversation_log = {"chat": []}
for user_message, gpt_message, context in message_list:
@ -33,13 +33,14 @@ def populate_chat_history(message_list):
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
)
# Update Conversation Metadata Logs in Application State
state.processor_config.conversation.meta_log = conversation_log
# Update Conversation Metadata Logs in Database
ConversationFactory(user=user, conversation_log=conversation_log)
# Tests
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# Act
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
@ -54,14 +55,15 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_from_chat_history(chat_client):
def test_answer_from_chat_history(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
@ -76,8 +78,9 @@ def test_answer_from_chat_history(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_from_currently_retrieved_content(chat_client):
def test_answer_from_currently_retrieved_content(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
@ -87,7 +90,7 @@ def test_answer_from_currently_retrieved_content(chat_client):
["Testatron was born on 1st April 1984 in Testville."],
),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')
@ -99,8 +102,9 @@ def test_answer_from_currently_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_no_background, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
@ -110,10 +114,10 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
["Testatron was born on 1st April 1984 in Testville."],
),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"')
response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"')
response_message = response.content.decode("utf-8")
# Assert
@ -125,14 +129,15 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_from_chat_history_and_currently_retrieved_content(chat_client):
def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"')
@ -148,15 +153,16 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_user2: KhojUser):
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true')
@ -171,12 +177,13 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_using_general_command(chat_client):
def test_answer_using_general_command(chat_client, default_user2: KhojUser):
# Arrange
query = urllib.parse.quote("/general Where was Xi Li born?")
message_list = []
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f"/api/chat?q={query}&stream=true")
@ -188,12 +195,13 @@ def test_answer_using_general_command(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_from_retrieved_content_using_notes_command(chat_client):
def test_answer_from_retrieved_content_using_notes_command(chat_client, default_user2: KhojUser):
# Arrange
query = urllib.parse.quote("/notes Where was Xi Li born?")
message_list = []
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f"/api/chat?q={query}&stream=true")
@ -205,15 +213,16 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_not_known_using_notes_command(chat_client):
def test_answer_not_known_using_notes_command(chat_client_no_background, default_user2: KhojUser):
# Arrange
query = urllib.parse.quote("/notes Where was Testatron born?")
message_list = []
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f"/api/chat?q={query}&stream=true")
response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
@ -223,6 +232,7 @@ def test_answer_not_known_using_notes_command(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
@freeze_time("2023-04-01")
def test_answer_requires_current_date_awareness(chat_client):
@ -240,11 +250,13 @@ def test_answer_requires_current_date_awareness(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
@freeze_time("2023-04-01")
def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client):
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
# Act
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true')
response_message = response.content.decode("utf-8")
@ -254,15 +266,16 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_clien
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_client):
def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(
@ -280,10 +293,12 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background):
# Act
response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
response_message = response.content.decode("utf-8")
# Assert
@ -301,15 +316,16 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_in_chat_history_beyond_lookback_window(chat_client):
def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user2: KhojUser):
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
populate_chat_history(message_list)
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
@ -324,6 +340,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
@pytest.mark.chatquality
def test_answer_requires_multiple_independent_searches(chat_client):
"Chat director should be able to answer by doing multiple independent searches for required information"
@ -340,10 +357,12 @@ def test_answer_requires_multiple_independent_searches(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_answer_using_file_filter(chat_client):
"Chat should be able to use search filters in the query"
# Act
query = urllib.parse.quote('Is Xi older than Namita? file:"Namita.markdown" file:"Xi Li.markdown"')
response = chat_client.get(f"/api/chat?q={query}&stream=true")
response_message = response.content.decode("utf-8")


@ -13,12 +13,11 @@ from khoj.search_type import text_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
from khoj.processor.github.github_to_jsonl import GithubToJsonl
from khoj.utils.config import SearchModels
from khoj.utils.fs_syncer import get_org_files, collect_files
from khoj.utils.fs_syncer import collect_files, get_org_files
from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig
logger = logging.getLogger(__name__)
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
from khoj.utils.rawconfig import ContentConfig, SearchConfig
# Test