mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 09:25:06 +01:00
[Multi-User Part 3]: Separate chat sesssions based on authenticated users (#511)
- Add a data model which allows us to store Conversations with users. This does a minimal lift over the current setup, where the underlying data is stored in a JSON file. This maintains parity with that configuration. - There does _seem_ to be some regression in chat quality, which is most likely attributable to search results. This will help us with #275. It should become much easier to maintain multiple Conversations in a given table in the backend now. We will have to do some thinking on the UI.
This commit is contained in:
parent
a8a82d274a
commit
4b6ec248a6
24 changed files with 719 additions and 626 deletions
|
@ -2,3 +2,5 @@
|
|||
DJANGO_SETTINGS_MODULE = app.settings
|
||||
pythonpath = . src
|
||||
testpaths = tests
|
||||
markers =
|
||||
chatquality: marks tests as chatquality (deselect with '-m "not chatquality"')
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from typing import Type, TypeVar, List
|
||||
import uuid
|
||||
from datetime import date
|
||||
|
||||
from django.db import models
|
||||
|
@ -21,6 +20,13 @@ from database.models import (
|
|||
GithubConfig,
|
||||
Embeddings,
|
||||
GithubRepoConfig,
|
||||
Conversation,
|
||||
ConversationProcessorConfig,
|
||||
OpenAIProcessorConversationConfig,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
)
|
||||
from khoj.utils.rawconfig import (
|
||||
ConversationProcessorConfig as UserConversationProcessorConfig,
|
||||
)
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
|
@ -54,18 +60,17 @@ async def get_or_create_user(token: dict) -> KhojUser:
|
|||
|
||||
|
||||
async def create_google_user(token: dict) -> KhojUser:
|
||||
user_info = token.get("userinfo")
|
||||
user = await KhojUser.objects.acreate(username=user_info.get("email"), email=user_info.get("email"))
|
||||
user = await KhojUser.objects.acreate(username=token.get("email"), email=token.get("email"))
|
||||
await user.asave()
|
||||
await GoogleUser.objects.acreate(
|
||||
sub=user_info.get("sub"),
|
||||
azp=user_info.get("azp"),
|
||||
email=user_info.get("email"),
|
||||
name=user_info.get("name"),
|
||||
given_name=user_info.get("given_name"),
|
||||
family_name=user_info.get("family_name"),
|
||||
picture=user_info.get("picture"),
|
||||
locale=user_info.get("locale"),
|
||||
sub=token.get("sub"),
|
||||
azp=token.get("azp"),
|
||||
email=token.get("email"),
|
||||
name=token.get("name"),
|
||||
given_name=token.get("given_name"),
|
||||
family_name=token.get("family_name"),
|
||||
picture=token.get("picture"),
|
||||
locale=token.get("locale"),
|
||||
user=user,
|
||||
)
|
||||
|
||||
|
@ -137,6 +142,124 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
|||
return config
|
||||
|
||||
|
||||
class ConversationAdapters:
|
||||
@staticmethod
|
||||
def get_conversation_by_user(user: KhojUser):
|
||||
conversation = Conversation.objects.filter(user=user)
|
||||
if conversation.exists():
|
||||
return conversation.first()
|
||||
return Conversation.objects.create(user=user)
|
||||
|
||||
@staticmethod
|
||||
async def aget_conversation_by_user(user: KhojUser):
|
||||
conversation = Conversation.objects.filter(user=user)
|
||||
if await conversation.aexists():
|
||||
return await conversation.afirst()
|
||||
return await Conversation.objects.acreate(user=user)
|
||||
|
||||
@staticmethod
|
||||
def has_any_conversation_config(user: KhojUser):
|
||||
return ConversationProcessorConfig.objects.filter(user=user).exists()
|
||||
|
||||
@staticmethod
|
||||
def get_openai_conversation_config(user: KhojUser):
|
||||
return OpenAIProcessorConversationConfig.objects.filter(user=user).first()
|
||||
|
||||
@staticmethod
|
||||
def get_offline_chat_conversation_config(user: KhojUser):
|
||||
return OfflineChatProcessorConversationConfig.objects.filter(user=user).first()
|
||||
|
||||
@staticmethod
|
||||
def has_valid_offline_conversation_config(user: KhojUser):
|
||||
return OfflineChatProcessorConversationConfig.objects.filter(user=user, enable_offline_chat=True).exists()
|
||||
|
||||
@staticmethod
|
||||
def has_valid_openai_conversation_config(user: KhojUser):
|
||||
return OpenAIProcessorConversationConfig.objects.filter(user=user).exists()
|
||||
|
||||
@staticmethod
|
||||
def get_conversation_config(user: KhojUser):
|
||||
return ConversationProcessorConfig.objects.filter(user=user).first()
|
||||
|
||||
@staticmethod
|
||||
def save_conversation(user: KhojUser, conversation_log: dict):
|
||||
conversation = Conversation.objects.filter(user=user)
|
||||
if conversation.exists():
|
||||
conversation.update(conversation_log=conversation_log)
|
||||
else:
|
||||
Conversation.objects.create(user=user, conversation_log=conversation_log)
|
||||
|
||||
@staticmethod
|
||||
def set_conversation_processor_config(user: KhojUser, new_config: UserConversationProcessorConfig):
|
||||
conversation_config, _ = ConversationProcessorConfig.objects.get_or_create(user=user)
|
||||
conversation_config.max_prompt_size = new_config.max_prompt_size
|
||||
conversation_config.tokenizer = new_config.tokenizer
|
||||
conversation_config.save()
|
||||
|
||||
if new_config.openai:
|
||||
default_values = {
|
||||
"api_key": new_config.openai.api_key,
|
||||
}
|
||||
if new_config.openai.chat_model:
|
||||
default_values["chat_model"] = new_config.openai.chat_model
|
||||
|
||||
OpenAIProcessorConversationConfig.objects.update_or_create(user=user, defaults=default_values)
|
||||
|
||||
if new_config.offline_chat:
|
||||
default_values = {
|
||||
"enable_offline_chat": str(new_config.offline_chat.enable_offline_chat),
|
||||
}
|
||||
|
||||
if new_config.offline_chat.chat_model:
|
||||
default_values["chat_model"] = new_config.offline_chat.chat_model
|
||||
|
||||
OfflineChatProcessorConversationConfig.objects.update_or_create(user=user, defaults=default_values)
|
||||
|
||||
@staticmethod
|
||||
def get_enabled_conversation_settings(user: KhojUser):
|
||||
openai_config = ConversationAdapters.get_openai_conversation_config(user)
|
||||
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user)
|
||||
|
||||
return {
|
||||
"openai": True if openai_config is not None else False,
|
||||
"offline_chat": True
|
||||
if (offline_chat_config is not None and offline_chat_config.enable_offline_chat)
|
||||
else False,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def clear_conversation_config(user: KhojUser):
|
||||
ConversationProcessorConfig.objects.filter(user=user).delete()
|
||||
ConversationAdapters.clear_openai_conversation_config(user)
|
||||
ConversationAdapters.clear_offline_chat_conversation_config(user)
|
||||
|
||||
@staticmethod
|
||||
def clear_openai_conversation_config(user: KhojUser):
|
||||
OpenAIProcessorConversationConfig.objects.filter(user=user).delete()
|
||||
|
||||
@staticmethod
|
||||
def clear_offline_chat_conversation_config(user: KhojUser):
|
||||
OfflineChatProcessorConversationConfig.objects.filter(user=user).delete()
|
||||
|
||||
@staticmethod
|
||||
async def has_offline_chat(user: KhojUser):
|
||||
return await OfflineChatProcessorConversationConfig.objects.filter(
|
||||
user=user, enable_offline_chat=True
|
||||
).aexists()
|
||||
|
||||
@staticmethod
|
||||
async def get_offline_chat(user: KhojUser):
|
||||
return await OfflineChatProcessorConversationConfig.objects.filter(user=user).afirst()
|
||||
|
||||
@staticmethod
|
||||
async def has_openai_chat(user: KhojUser):
|
||||
return await OpenAIProcessorConversationConfig.objects.filter(user=user).aexists()
|
||||
|
||||
@staticmethod
|
||||
async def get_openai_chat(user: KhojUser):
|
||||
return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst()
|
||||
|
||||
|
||||
class EmbeddingsAdapters:
|
||||
word_filer = WordFilter()
|
||||
file_filter = FileFilter()
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
# Generated by Django 4.2.5 on 2023-10-18 05:31
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0006_embeddingsdates"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RemoveField(
|
||||
model_name="conversationprocessorconfig",
|
||||
name="conversation",
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name="conversationprocessorconfig",
|
||||
name="enable_offline_chat",
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="conversationprocessorconfig",
|
||||
name="max_prompt_size",
|
||||
field=models.IntegerField(blank=True, default=None, null=True),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="conversationprocessorconfig",
|
||||
name="tokenizer",
|
||||
field=models.CharField(blank=True, default=None, max_length=200, null=True),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="conversationprocessorconfig",
|
||||
name="user",
|
||||
field=models.ForeignKey(
|
||||
default=1, on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
|
||||
),
|
||||
preserve_default=False,
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="OpenAIProcessorConversationConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("api_key", models.CharField(max_length=200)),
|
||||
("chat_model", models.CharField(max_length=200)),
|
||||
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="OfflineChatProcessorConversationConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("enable_offline_chat", models.BooleanField(default=False)),
|
||||
("chat_model", models.CharField(default="llama-2-7b-chat.ggmlv3.q4_0.bin", max_length=200)),
|
||||
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="Conversation",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("conversation_log", models.JSONField()),
|
||||
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
]
|
|
@ -0,0 +1,17 @@
|
|||
# Generated by Django 4.2.5 on 2023-10-18 16:46
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0007_remove_conversationprocessorconfig_conversation_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="conversation",
|
||||
name="conversation_log",
|
||||
field=models.JSONField(default=dict),
|
||||
),
|
||||
]
|
|
@ -82,9 +82,27 @@ class LocalPlaintextConfig(BaseModel):
|
|||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class ConversationProcessorConfig(BaseModel):
|
||||
conversation = models.JSONField()
|
||||
class OpenAIProcessorConversationConfig(BaseModel):
|
||||
api_key = models.CharField(max_length=200)
|
||||
chat_model = models.CharField(max_length=200)
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class OfflineChatProcessorConversationConfig(BaseModel):
|
||||
enable_offline_chat = models.BooleanField(default=False)
|
||||
chat_model = models.CharField(max_length=200, default="llama-2-7b-chat.ggmlv3.q4_0.bin")
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class ConversationProcessorConfig(BaseModel):
|
||||
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
conversation_log = models.JSONField(default=dict)
|
||||
|
||||
|
||||
class Embeddings(BaseModel):
|
||||
|
|
|
@ -23,12 +23,10 @@ from starlette.authentication import (
|
|||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import (
|
||||
SearchType,
|
||||
ProcessorConfigModel,
|
||||
ConversationProcessorConfigModel,
|
||||
)
|
||||
from khoj.utils.helpers import resolve_absolute_path, merge_dicts
|
||||
from khoj.utils.helpers import merge_dicts
|
||||
from khoj.utils.fs_syncer import collect_files
|
||||
from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||
from database.models import KhojUser
|
||||
from database.adapters import get_all_users
|
||||
|
@ -98,13 +96,6 @@ def configure_server(
|
|||
# Update Config
|
||||
state.config = config
|
||||
|
||||
# Initialize Processor from Config
|
||||
try:
|
||||
state.processor_config = configure_processor(state.config.processor)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to configure processor", exc_info=True)
|
||||
raise e
|
||||
|
||||
# Initialize Search Models from Config and initialize content
|
||||
try:
|
||||
state.config_lock.acquire()
|
||||
|
@ -190,103 +181,6 @@ def configure_search_types(config: FullConfig):
|
|||
return Enum("SearchType", merge_dicts(core_search_types, {}))
|
||||
|
||||
|
||||
def configure_processor(
|
||||
processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
|
||||
):
|
||||
if not processor_config:
|
||||
logger.warning("🚨 No Processor configuration available.")
|
||||
return None
|
||||
|
||||
processor = ProcessorConfigModel()
|
||||
|
||||
# Initialize Conversation Processor
|
||||
logger.info("💬 Setting up conversation processor")
|
||||
processor.conversation = configure_conversation_processor(processor_config, state_processor_config)
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
def configure_conversation_processor(
|
||||
processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
|
||||
):
|
||||
if (
|
||||
not processor_config
|
||||
or not processor_config.conversation
|
||||
or not processor_config.conversation.conversation_logfile
|
||||
):
|
||||
default_config = constants.default_config
|
||||
default_conversation_logfile = resolve_absolute_path(
|
||||
default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
|
||||
)
|
||||
conversation_logfile = resolve_absolute_path(default_conversation_logfile)
|
||||
conversation_config = processor_config.conversation if processor_config else None
|
||||
conversation_processor = ConversationProcessorConfigModel(
|
||||
conversation_config=ConversationProcessorConfig(
|
||||
conversation_logfile=conversation_logfile,
|
||||
openai=(conversation_config.openai if (conversation_config is not None) else None),
|
||||
offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(),
|
||||
max_prompt_size=conversation_config.max_prompt_size if conversation_config else None,
|
||||
tokenizer=conversation_config.tokenizer if conversation_config else None,
|
||||
)
|
||||
)
|
||||
else:
|
||||
conversation_processor = ConversationProcessorConfigModel(
|
||||
conversation_config=processor_config.conversation,
|
||||
)
|
||||
conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile)
|
||||
|
||||
# Load Conversation Logs from Disk
|
||||
if state_processor_config and state_processor_config.conversation and state_processor_config.conversation.meta_log:
|
||||
conversation_processor.meta_log = state_processor_config.conversation.meta_log
|
||||
conversation_processor.chat_session = state_processor_config.conversation.chat_session
|
||||
logger.debug(f"Loaded conversation logs from state")
|
||||
return conversation_processor
|
||||
|
||||
if conversation_logfile.is_file():
|
||||
# Load Metadata Logs from Conversation Logfile
|
||||
with conversation_logfile.open("r") as f:
|
||||
conversation_processor.meta_log = json.load(f)
|
||||
logger.debug(f"Loaded conversation logs from {conversation_logfile}")
|
||||
else:
|
||||
# Initialize Conversation Logs
|
||||
conversation_processor.meta_log = {}
|
||||
conversation_processor.chat_session = []
|
||||
|
||||
return conversation_processor
|
||||
|
||||
|
||||
@schedule.repeat(schedule.every(17).minutes)
|
||||
def save_chat_session():
|
||||
# No need to create empty log file
|
||||
if not (
|
||||
state.processor_config
|
||||
and state.processor_config.conversation
|
||||
and state.processor_config.conversation.meta_log
|
||||
and state.processor_config.conversation.chat_session
|
||||
):
|
||||
return
|
||||
|
||||
# Summarize Conversation Logs for this Session
|
||||
conversation_log = state.processor_config.conversation.meta_log
|
||||
session = {
|
||||
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
|
||||
"session-end": len(conversation_log["chat"]),
|
||||
}
|
||||
if "session" in conversation_log:
|
||||
conversation_log["session"].append(session)
|
||||
else:
|
||||
conversation_log["session"] = [session]
|
||||
|
||||
# Save Conversation Metadata Logs to Disk
|
||||
conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile)
|
||||
conversation_logfile.parent.mkdir(parents=True, exist_ok=True) # create conversation directory if doesn't exist
|
||||
with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
|
||||
json.dump(conversation_log, logfile, indent=2)
|
||||
|
||||
state.processor_config.conversation.chat_session = []
|
||||
logger.info("📩 Saved current chat session to conversation logs")
|
||||
|
||||
|
||||
@schedule.repeat(schedule.every(59).minutes)
|
||||
def upload_telemetry():
|
||||
if not state.config or not state.config.app or not state.config.app.should_log_telemetry or not state.telemetry:
|
||||
|
|
|
@ -3,6 +3,11 @@
|
|||
|
||||
<div class="page">
|
||||
<div class="section">
|
||||
{% if anonymous_mode == False %}
|
||||
<div>
|
||||
Logged in as {{ username }}
|
||||
</div>
|
||||
{% endif %}
|
||||
<h2 class="section-title">Plugins</h2>
|
||||
<div class="section-cards">
|
||||
<div class="card">
|
||||
|
@ -257,8 +262,8 @@
|
|||
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
|
||||
<h3 class="card-title">
|
||||
Offline Chat
|
||||
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||
{% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and not current_model_state.conversation_gpt4all %}
|
||||
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_model_state.enable_offline_model and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||
{% if current_model_state.enable_offline_model and not current_model_state.conversation_gpt4all %}
|
||||
<img id="misconfigured-icon-conversation-enable-offline-chat" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The model was not downloaded as expected.">
|
||||
{% endif %}
|
||||
</h3>
|
||||
|
@ -266,12 +271,12 @@
|
|||
<div class="card-description-row">
|
||||
<p class="card-description">Setup offline chat</p>
|
||||
</div>
|
||||
<div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
|
||||
<div id="clear-enable-offline-chat" class="card-action-row {% if current_model_state.enable_offline_model %}enabled{% else %}disabled{% endif %}">
|
||||
<button class="card-button" onclick="toggleEnableLocalLLLM(false)">
|
||||
Disable
|
||||
</button>
|
||||
</div>
|
||||
<div id="set-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}disabled{% else %}enabled{% endif %}">
|
||||
<div id="set-enable-offline-chat" class="card-action-row {% if current_model_state.enable_offline_model %}disabled{% else %}enabled{% endif %}">
|
||||
<button class="card-button happy" onclick="toggleEnableLocalLLLM(true)">
|
||||
Enable
|
||||
</button>
|
||||
|
|
|
@ -8,21 +8,20 @@ from typing import List, Optional, Union, Any
|
|||
import asyncio
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter, HTTPException, Header, Request, Depends
|
||||
from fastapi import APIRouter, HTTPException, Header, Request
|
||||
from starlette.authentication import requires
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_processor, configure_server
|
||||
from khoj.configure import configure_server
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.utils.config import TextSearchModel
|
||||
from khoj.utils.config import TextSearchModel, GPT4AllProcessorModel
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer, command_descriptions
|
||||
from khoj.utils.rawconfig import (
|
||||
FullConfig,
|
||||
ProcessorConfig,
|
||||
SearchConfig,
|
||||
SearchResponse,
|
||||
TextContentConfig,
|
||||
|
@ -32,16 +31,16 @@ from khoj.utils.rawconfig import (
|
|||
ConversationProcessorConfig,
|
||||
OfflineChatProcessorConfig,
|
||||
)
|
||||
from khoj.utils.helpers import resolve_absolute_path
|
||||
from khoj.utils.state import SearchType
|
||||
from khoj.utils import state, constants
|
||||
from khoj.utils.yaml import save_config_to_file_updated_state
|
||||
from khoj.utils.helpers import AsyncIteratorWrapper
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from khoj.routers.helpers import (
|
||||
get_conversation_command,
|
||||
perform_chat_checks,
|
||||
generate_chat_response,
|
||||
agenerate_chat_response,
|
||||
update_telemetry_state,
|
||||
is_ready_to_chat,
|
||||
)
|
||||
from khoj.processor.conversation.prompts import help_message
|
||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||
|
@ -49,7 +48,7 @@ from khoj.processor.conversation.gpt4all.chat_model import extract_questions_off
|
|||
from fastapi.requests import Request
|
||||
|
||||
from database import adapters
|
||||
from database.adapters import EmbeddingsAdapters
|
||||
from database.adapters import EmbeddingsAdapters, ConversationAdapters
|
||||
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser
|
||||
|
||||
|
||||
|
@ -114,6 +113,8 @@ async def map_config_to_db(config: FullConfig, user: KhojUser):
|
|||
user=user,
|
||||
token=config.content_type.notion.token,
|
||||
)
|
||||
if config.processor and config.processor.conversation:
|
||||
ConversationAdapters.set_conversation_processor_config(user, config.processor.conversation)
|
||||
|
||||
|
||||
# If it's a demo instance, prevent updating any of the configuration.
|
||||
|
@ -123,8 +124,6 @@ if not state.demo:
|
|||
if state.config is None:
|
||||
state.config = FullConfig()
|
||||
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
|
||||
if state.processor_config is None:
|
||||
state.processor_config = configure_processor(state.config.processor)
|
||||
|
||||
@api.get("/config/data", response_model=FullConfig)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
|
@ -238,28 +237,24 @@ if not state.demo:
|
|||
)
|
||||
|
||||
content_object = map_config_to_object(content_type)
|
||||
if content_object is None:
|
||||
raise ValueError(f"Invalid content type: {content_type}")
|
||||
|
||||
await content_object.objects.filter(user=user).adelete()
|
||||
await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type)
|
||||
|
||||
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
async def remove_processor_conversation_config_data(
|
||||
request: Request,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
if (
|
||||
not state.config
|
||||
or not state.config.processor
|
||||
or not state.config.processor.conversation
|
||||
or not state.config.processor.conversation.openai
|
||||
):
|
||||
return {"status": "ok"}
|
||||
user = request.user.object
|
||||
|
||||
state.config.processor.conversation.openai = None
|
||||
state.processor_config = configure_processor(state.config.processor, state.processor_config)
|
||||
await sync_to_async(ConversationAdapters.clear_openai_conversation_config)(user)
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
@ -269,11 +264,7 @@ if not state.demo:
|
|||
metadata={"processor_conversation_type": "openai"},
|
||||
)
|
||||
|
||||
try:
|
||||
save_config_to_file_updated_state()
|
||||
return {"status": "ok"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
return {"status": "ok"}
|
||||
|
||||
@api.post("/config/data/content_type/{content_type}", status_code=200)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
|
@ -301,24 +292,17 @@ if not state.demo:
|
|||
return {"status": "ok"}
|
||||
|
||||
@api.post("/config/data/processor/conversation/openai", status_code=200)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
async def set_processor_openai_config_data(
|
||||
request: Request,
|
||||
updated_config: Union[OpenAIProcessorConfig, None],
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
_initialize_config()
|
||||
user = request.user.object
|
||||
|
||||
if not state.config.processor or not state.config.processor.conversation:
|
||||
default_config = constants.default_config
|
||||
default_conversation_logfile = resolve_absolute_path(
|
||||
default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
|
||||
)
|
||||
conversation_logfile = resolve_absolute_path(default_conversation_logfile)
|
||||
state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
|
||||
conversation_config = ConversationProcessorConfig(openai=updated_config)
|
||||
|
||||
assert state.config.processor.conversation is not None
|
||||
state.config.processor.conversation.openai = updated_config
|
||||
state.processor_config = configure_processor(state.config.processor, state.processor_config)
|
||||
await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config)
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
@ -328,11 +312,7 @@ if not state.demo:
|
|||
metadata={"processor_conversation_type": "conversation"},
|
||||
)
|
||||
|
||||
try:
|
||||
save_config_to_file_updated_state()
|
||||
return {"status": "ok"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
return {"status": "ok"}
|
||||
|
||||
@api.post("/config/data/processor/conversation/offline_chat", status_code=200)
|
||||
async def set_processor_enable_offline_chat_config_data(
|
||||
|
@ -341,24 +321,26 @@ if not state.demo:
|
|||
offline_chat_model: Optional[str] = None,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
_initialize_config()
|
||||
user = request.user.object
|
||||
|
||||
if not state.config.processor or not state.config.processor.conversation:
|
||||
default_config = constants.default_config
|
||||
default_conversation_logfile = resolve_absolute_path(
|
||||
default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
|
||||
if enable_offline_chat:
|
||||
conversation_config = ConversationProcessorConfig(
|
||||
offline_chat=OfflineChatProcessorConfig(
|
||||
enable_offline_chat=enable_offline_chat,
|
||||
chat_model=offline_chat_model,
|
||||
)
|
||||
)
|
||||
conversation_logfile = resolve_absolute_path(default_conversation_logfile)
|
||||
state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
|
||||
|
||||
assert state.config.processor.conversation is not None
|
||||
if state.config.processor.conversation.offline_chat is None:
|
||||
state.config.processor.conversation.offline_chat = OfflineChatProcessorConfig()
|
||||
await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config)
|
||||
|
||||
state.config.processor.conversation.offline_chat.enable_offline_chat = enable_offline_chat
|
||||
if offline_chat_model is not None:
|
||||
state.config.processor.conversation.offline_chat.chat_model = offline_chat_model
|
||||
state.processor_config = configure_processor(state.config.processor, state.processor_config)
|
||||
offline_chat = await ConversationAdapters.get_offline_chat(user)
|
||||
chat_model = offline_chat.chat_model
|
||||
if state.gpt4all_processor_config is None:
|
||||
state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
|
||||
|
||||
else:
|
||||
await sync_to_async(ConversationAdapters.clear_offline_chat_conversation_config)(user)
|
||||
state.gpt4all_processor_config = None
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
@ -368,11 +350,7 @@ if not state.demo:
|
|||
metadata={"processor_conversation_type": f"{'enable' if enable_offline_chat else 'disable'}_local_llm"},
|
||||
)
|
||||
|
||||
try:
|
||||
save_config_to_file_updated_state()
|
||||
return {"status": "ok"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# Create Routes
|
||||
|
@ -426,9 +404,6 @@ async def search(
|
|||
if q is None or q == "":
|
||||
logger.warning(f"No query param (q) passed in API call to initiate search")
|
||||
return results
|
||||
if not state.search_models or not any(state.search_models.__dict__.values()):
|
||||
logger.warning(f"No search models loaded. Configure a search model before initiating search")
|
||||
return results
|
||||
|
||||
# initialize variables
|
||||
user_query = q.strip()
|
||||
|
@ -565,8 +540,6 @@ def update(
|
|||
components.append("Search models")
|
||||
if state.content_index:
|
||||
components.append("Content index")
|
||||
if state.processor_config:
|
||||
components.append("Conversation processor")
|
||||
components_msg = ", ".join(components)
|
||||
logger.info(f"📪 {components_msg} updated via API")
|
||||
|
||||
|
@ -592,12 +565,11 @@ def chat_history(
|
|||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
):
|
||||
perform_chat_checks()
|
||||
user = request.user.object
|
||||
perform_chat_checks(user)
|
||||
|
||||
# Load Conversation History
|
||||
meta_log = {}
|
||||
if state.processor_config.conversation:
|
||||
meta_log = state.processor_config.conversation.meta_log
|
||||
meta_log = ConversationAdapters.get_conversation_by_user(user=user).conversation_log
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
@ -649,30 +621,35 @@ async def chat(
|
|||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
) -> Response:
|
||||
perform_chat_checks()
|
||||
user = request.user.object
|
||||
|
||||
await is_ready_to_chat(user)
|
||||
conversation_command = get_conversation_command(query=q, any_references=True)
|
||||
|
||||
q = q.replace(f"/{conversation_command.value}", "").strip()
|
||||
|
||||
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||
request, q, (n or 5), conversation_command
|
||||
request, meta_log, q, (n or 5), conversation_command
|
||||
)
|
||||
|
||||
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
|
||||
conversation_command = ConversationCommand.General
|
||||
|
||||
if conversation_command == ConversationCommand.Help:
|
||||
model_type = "offline" if state.processor_config.conversation.offline_chat.enable_offline_chat else "openai"
|
||||
model_type = "offline" if await ConversationAdapters.has_offline_chat(user) else "openai"
|
||||
formatted_help = help_message.format(model=model_type, version=state.khoj_version)
|
||||
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
|
||||
|
||||
# Get the (streamed) chat response from the LLM of choice.
|
||||
llm_response = generate_chat_response(
|
||||
llm_response = await agenerate_chat_response(
|
||||
defiltered_query,
|
||||
meta_log=state.processor_config.conversation.meta_log,
|
||||
compiled_references=compiled_references,
|
||||
inferred_queries=inferred_queries,
|
||||
conversation_command=conversation_command,
|
||||
meta_log,
|
||||
compiled_references,
|
||||
inferred_queries,
|
||||
conversation_command,
|
||||
user,
|
||||
)
|
||||
|
||||
if llm_response is None:
|
||||
|
@ -681,13 +658,14 @@ async def chat(
|
|||
if stream:
|
||||
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
|
||||
|
||||
iterator = AsyncIteratorWrapper(llm_response)
|
||||
|
||||
# Get the full response from the generator if the stream is not requested.
|
||||
aggregated_gpt_response = ""
|
||||
while True:
|
||||
try:
|
||||
aggregated_gpt_response += next(llm_response)
|
||||
except StopIteration:
|
||||
async for item in iterator:
|
||||
if item is None:
|
||||
break
|
||||
aggregated_gpt_response += item
|
||||
|
||||
actual_response = aggregated_gpt_response.split("### compiled references:")[0]
|
||||
|
||||
|
@ -708,44 +686,53 @@ async def chat(
|
|||
|
||||
async def extract_references_and_questions(
|
||||
request: Request,
|
||||
meta_log: dict,
|
||||
q: str,
|
||||
n: int,
|
||||
conversation_type: ConversationCommand = ConversationCommand.Default,
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
# Load Conversation History
|
||||
meta_log = state.processor_config.conversation.meta_log
|
||||
|
||||
# Initialize Variables
|
||||
compiled_references: List[Any] = []
|
||||
inferred_queries: List[str] = []
|
||||
|
||||
if not EmbeddingsAdapters.user_has_embeddings(user=user):
|
||||
if conversation_type == ConversationCommand.General:
|
||||
return compiled_references, inferred_queries, q
|
||||
|
||||
if not await EmbeddingsAdapters.user_has_embeddings(user=user):
|
||||
logger.warning(
|
||||
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
|
||||
)
|
||||
return compiled_references, inferred_queries, q
|
||||
|
||||
if conversation_type == ConversationCommand.General:
|
||||
return compiled_references, inferred_queries, q
|
||||
|
||||
# Extract filter terms from user message
|
||||
defiltered_query = q
|
||||
for filter in [DateFilter(), WordFilter(), FileFilter()]:
|
||||
defiltered_query = filter.defilter(defiltered_query)
|
||||
filters_in_query = q.replace(defiltered_query, "").strip()
|
||||
|
||||
using_offline_chat = False
|
||||
|
||||
# Infer search queries from user message
|
||||
with timer("Extracting search queries took", logger):
|
||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||
if state.processor_config.conversation.offline_chat.enable_offline_chat:
|
||||
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
if await ConversationAdapters.has_offline_chat(user):
|
||||
using_offline_chat = True
|
||||
offline_chat = await ConversationAdapters.get_offline_chat(user)
|
||||
chat_model = offline_chat.chat_model
|
||||
if state.gpt4all_processor_config is None:
|
||||
state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
|
||||
|
||||
loaded_model = state.gpt4all_processor_config.loaded_model
|
||||
|
||||
inferred_queries = extract_questions_offline(
|
||||
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
||||
)
|
||||
elif state.processor_config.conversation.openai_model:
|
||||
api_key = state.processor_config.conversation.openai_model.api_key
|
||||
chat_model = state.processor_config.conversation.openai_model.chat_model
|
||||
elif await ConversationAdapters.has_openai_chat(user):
|
||||
openai_chat = await ConversationAdapters.get_openai_chat(user)
|
||||
api_key = openai_chat.api_key
|
||||
chat_model = openai_chat.chat_model
|
||||
inferred_queries = extract_questions(
|
||||
defiltered_query, model=chat_model, api_key=api_key, conversation_log=meta_log
|
||||
)
|
||||
|
@ -754,7 +741,7 @@ async def extract_references_and_questions(
|
|||
with timer("Searching knowledge base took", logger):
|
||||
result_list = []
|
||||
for query in inferred_queries:
|
||||
n_items = min(n, 3) if state.processor_config.conversation.offline_chat.enable_offline_chat else n
|
||||
n_items = min(n, 3) if using_offline_chat else n
|
||||
result_list.extend(
|
||||
await search(
|
||||
f"{query} {filters_in_query}",
|
||||
|
@ -765,6 +752,8 @@ async def extract_references_and_questions(
|
|||
dedupe=False,
|
||||
)
|
||||
)
|
||||
# Dedupe the results again, as duplicates may be returned across queries.
|
||||
result_list = text_search.deduplicated_search_responses(result_list)
|
||||
compiled_references = [item.additional["compiled"] for item in result_list]
|
||||
|
||||
return compiled_references, inferred_queries, defiltered_query
|
||||
|
|
|
@ -1,34 +1,50 @@
|
|||
import logging
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Iterator, List, Optional, Union
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import ConversationCommand, timer, log_telemetry
|
||||
from khoj.utils.config import GPT4AllProcessorModel
|
||||
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
||||
from khoj.processor.conversation.openai.gpt import converse
|
||||
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
|
||||
from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
|
||||
from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
|
||||
from database.models import KhojUser
|
||||
from database.adapters import ConversationAdapters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
def perform_chat_checks():
|
||||
if (
|
||||
state.processor_config
|
||||
and state.processor_config.conversation
|
||||
and (
|
||||
state.processor_config.conversation.openai_model
|
||||
or state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
)
|
||||
):
|
||||
|
||||
def perform_chat_checks(user: KhojUser):
|
||||
if ConversationAdapters.has_valid_offline_conversation_config(
|
||||
user
|
||||
) or ConversationAdapters.has_valid_openai_conversation_config(user):
|
||||
return
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings and restart it."
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
||||
|
||||
|
||||
async def is_ready_to_chat(user: KhojUser):
|
||||
has_offline_config = await ConversationAdapters.has_offline_chat(user=user)
|
||||
has_openai_config = await ConversationAdapters.has_openai_chat(user=user)
|
||||
|
||||
if has_offline_config:
|
||||
offline_chat = await ConversationAdapters.get_offline_chat(user)
|
||||
chat_model = offline_chat.chat_model
|
||||
if state.gpt4all_processor_config is None:
|
||||
state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
|
||||
return True
|
||||
|
||||
ready = has_openai_config or has_offline_config
|
||||
|
||||
if not ready:
|
||||
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
||||
|
||||
|
||||
def update_telemetry_state(
|
||||
|
@ -74,12 +90,22 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
|
|||
return ConversationCommand.Default
|
||||
|
||||
|
||||
async def construct_conversation_logs(user: KhojUser):
|
||||
return (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
||||
|
||||
|
||||
async def agenerate_chat_response(*args):
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(executor, generate_chat_response, *args)
|
||||
|
||||
|
||||
def generate_chat_response(
|
||||
q: str,
|
||||
meta_log: dict,
|
||||
compiled_references: List[str] = [],
|
||||
inferred_queries: List[str] = [],
|
||||
conversation_command: ConversationCommand = ConversationCommand.Default,
|
||||
user: KhojUser = None,
|
||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||
def _save_to_conversation_log(
|
||||
q: str,
|
||||
|
@ -89,17 +115,14 @@ def generate_chat_response(
|
|||
inferred_queries: List[str],
|
||||
meta_log,
|
||||
):
|
||||
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, chat_response])
|
||||
state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
||||
updated_conversation = message_to_log(
|
||||
user_message=q,
|
||||
chat_response=chat_response,
|
||||
user_message_metadata={"created": user_message_time},
|
||||
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
|
||||
conversation_log=meta_log.get("chat", []),
|
||||
)
|
||||
|
||||
# Load Conversation History
|
||||
meta_log = state.processor_config.conversation.meta_log
|
||||
ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
|
||||
|
||||
# Initialize Variables
|
||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
@ -116,8 +139,14 @@ def generate_chat_response(
|
|||
meta_log=meta_log,
|
||||
)
|
||||
|
||||
if state.processor_config.conversation.offline_chat.enable_offline_chat:
|
||||
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user=user)
|
||||
conversation_config = ConversationAdapters.get_conversation_config(user)
|
||||
openai_chat_config = ConversationAdapters.get_openai_conversation_config(user)
|
||||
if offline_chat_config:
|
||||
if state.gpt4all_processor_config.loaded_model is None:
|
||||
state.gpt4all_processor_config = GPT4AllProcessorModel(offline_chat_config.chat_model)
|
||||
|
||||
loaded_model = state.gpt4all_processor_config.loaded_model
|
||||
chat_response = converse_offline(
|
||||
references=compiled_references,
|
||||
user_query=q,
|
||||
|
@ -125,14 +154,14 @@ def generate_chat_response(
|
|||
conversation_log=meta_log,
|
||||
completion_func=partial_completion,
|
||||
conversation_command=conversation_command,
|
||||
model=state.processor_config.conversation.offline_chat.chat_model,
|
||||
max_prompt_size=state.processor_config.conversation.max_prompt_size,
|
||||
tokenizer_name=state.processor_config.conversation.tokenizer,
|
||||
model=offline_chat_config.chat_model,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
tokenizer_name=conversation_config.tokenizer,
|
||||
)
|
||||
|
||||
elif state.processor_config.conversation.openai_model:
|
||||
api_key = state.processor_config.conversation.openai_model.api_key
|
||||
chat_model = state.processor_config.conversation.openai_model.chat_model
|
||||
elif openai_chat_config:
|
||||
api_key = openai_chat_config.api_key
|
||||
chat_model = openai_chat_config.chat_model
|
||||
chat_response = converse(
|
||||
compiled_references,
|
||||
q,
|
||||
|
@ -141,8 +170,8 @@ def generate_chat_response(
|
|||
api_key=api_key,
|
||||
completion_func=partial_completion,
|
||||
conversation_command=conversation_command,
|
||||
max_prompt_size=state.processor_config.conversation.max_prompt_size,
|
||||
tokenizer_name=state.processor_config.conversation.tokenizer,
|
||||
max_prompt_size=conversation_config.max_prompt_size if conversation_config else None,
|
||||
tokenizer_name=conversation_config.tokenizer if conversation_config else None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
@ -92,7 +92,7 @@ async def update(
|
|||
|
||||
if dict_to_update is not None:
|
||||
dict_to_update[file.filename] = (
|
||||
file.file.read().decode("utf-8") if encoding == "utf-8" else file.file.read()
|
||||
file.file.read().decode("utf-8") if encoding == "utf-8" else file.file.read() # type: ignore
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")
|
||||
|
@ -181,24 +181,25 @@ def configure_content(
|
|||
files: Optional[dict[str, dict[str, str]]],
|
||||
search_models: SearchModels,
|
||||
regenerate: bool = False,
|
||||
t: Optional[Union[state.SearchType, str]] = None,
|
||||
t: Optional[state.SearchType] = None,
|
||||
full_corpus: bool = True,
|
||||
user: KhojUser = None,
|
||||
) -> Optional[ContentIndex]:
|
||||
content_index = ContentIndex()
|
||||
|
||||
if t in [type.value for type in state.SearchType]:
|
||||
t = state.SearchType(t).value
|
||||
if t is not None and not t.value in [type.value for type in state.SearchType]:
|
||||
logger.warning(f"🚨 Invalid search type: {t}")
|
||||
return None
|
||||
|
||||
assert type(t) == str or t == None, f"Invalid search type: {t}"
|
||||
search_type = t.value if t else None
|
||||
|
||||
if files is None:
|
||||
logger.warning(f"🚨 No files to process for {t} search.")
|
||||
logger.warning(f"🚨 No files to process for {search_type} search.")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Initialize Org Notes Search
|
||||
if (t == None or t == state.SearchType.Org.value) and files["org"]:
|
||||
if (search_type == None or search_type == state.SearchType.Org.value) and files["org"]:
|
||||
logger.info("🦄 Setting up search for orgmode notes")
|
||||
# Extract Entries, Generate Notes Embeddings
|
||||
text_search.setup(
|
||||
|
@ -213,7 +214,7 @@ def configure_content(
|
|||
|
||||
try:
|
||||
# Initialize Markdown Search
|
||||
if (t == None or t == state.SearchType.Markdown.value) and files["markdown"]:
|
||||
if (search_type == None or search_type == state.SearchType.Markdown.value) and files["markdown"]:
|
||||
logger.info("💎 Setting up search for markdown notes")
|
||||
# Extract Entries, Generate Markdown Embeddings
|
||||
text_search.setup(
|
||||
|
@ -229,7 +230,7 @@ def configure_content(
|
|||
|
||||
try:
|
||||
# Initialize PDF Search
|
||||
if (t == None or t == state.SearchType.Pdf.value) and files["pdf"]:
|
||||
if (search_type == None or search_type == state.SearchType.Pdf.value) and files["pdf"]:
|
||||
logger.info("🖨️ Setting up search for pdf")
|
||||
# Extract Entries, Generate PDF Embeddings
|
||||
text_search.setup(
|
||||
|
@ -245,7 +246,7 @@ def configure_content(
|
|||
|
||||
try:
|
||||
# Initialize Plaintext Search
|
||||
if (t == None or t == state.SearchType.Plaintext.value) and files["plaintext"]:
|
||||
if (search_type == None or search_type == state.SearchType.Plaintext.value) and files["plaintext"]:
|
||||
logger.info("📄 Setting up search for plaintext")
|
||||
# Extract Entries, Generate Plaintext Embeddings
|
||||
text_search.setup(
|
||||
|
@ -262,7 +263,7 @@ def configure_content(
|
|||
try:
|
||||
# Initialize Image Search
|
||||
if (
|
||||
(t == None or t == state.SearchType.Image.value)
|
||||
(search_type == None or search_type == state.SearchType.Image.value)
|
||||
and content_config
|
||||
and content_config.image
|
||||
and search_models.image_search
|
||||
|
@ -278,7 +279,7 @@ def configure_content(
|
|||
|
||||
try:
|
||||
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
|
||||
if (t == None or t == state.SearchType.Github.value) and github_config is not None:
|
||||
if (search_type == None or search_type == state.SearchType.Github.value) and github_config is not None:
|
||||
logger.info("🐙 Setting up search for github")
|
||||
# Extract Entries, Generate Github Embeddings
|
||||
text_search.setup(
|
||||
|
@ -296,7 +297,7 @@ def configure_content(
|
|||
try:
|
||||
# Initialize Notion Search
|
||||
notion_config = NotionConfig.objects.filter(user=user).first()
|
||||
if (t == None or t in state.SearchType.Notion.value) and notion_config:
|
||||
if (search_type == None or search_type in state.SearchType.Notion.value) and notion_config:
|
||||
logger.info("🔌 Setting up search for notion")
|
||||
text_search.setup(
|
||||
NotionToJsonl,
|
||||
|
|
|
@ -19,7 +19,7 @@ from khoj.utils.rawconfig import (
|
|||
|
||||
# Internal Packages
|
||||
from khoj.utils import constants, state
|
||||
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config
|
||||
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters
|
||||
from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
|
||||
|
||||
|
||||
|
@ -83,7 +83,7 @@ if not state.demo:
|
|||
@web_client.get("/config", response_class=HTMLResponse)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def config_page(request: Request):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
user = request.user.object
|
||||
enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
|
||||
default_full_config = FullConfig(
|
||||
content_type=None,
|
||||
|
@ -100,9 +100,6 @@ if not state.demo:
|
|||
"github": ("github" in enabled_content),
|
||||
"notion": ("notion" in enabled_content),
|
||||
"plaintext": ("plaintext" in enabled_content),
|
||||
"enable_offline_model": False,
|
||||
"conversation_openai": False,
|
||||
"conversation_gpt4all": False,
|
||||
}
|
||||
|
||||
if state.content_index:
|
||||
|
@ -112,13 +109,17 @@ if not state.demo:
|
|||
}
|
||||
)
|
||||
|
||||
if state.processor_config and state.processor_config.conversation:
|
||||
successfully_configured.update(
|
||||
{
|
||||
"conversation_openai": state.processor_config.conversation.openai_model is not None,
|
||||
"conversation_gpt4all": state.processor_config.conversation.gpt4all_model.loaded_model is not None,
|
||||
}
|
||||
)
|
||||
enabled_chat_config = ConversationAdapters.get_enabled_conversation_settings(user)
|
||||
|
||||
successfully_configured.update(
|
||||
{
|
||||
"conversation_openai": enabled_chat_config["openai"],
|
||||
"enable_offline_model": enabled_chat_config["offline_chat"],
|
||||
"conversation_gpt4all": state.gpt4all_processor_config.loaded_model is not None
|
||||
if state.gpt4all_processor_config
|
||||
else False,
|
||||
}
|
||||
)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
"config.html",
|
||||
|
@ -127,6 +128,7 @@ if not state.demo:
|
|||
"current_config": current_config,
|
||||
"current_model_state": successfully_configured,
|
||||
"anonymous_mode": state.anonymous_mode,
|
||||
"username": user.username if user else None,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -204,22 +206,22 @@ if not state.demo:
|
|||
)
|
||||
|
||||
@web_client.get("/config/processor/conversation/openai", response_class=HTMLResponse)
|
||||
@requires(["authenticated"], redirect="login_page")
|
||||
def conversation_processor_config_page(request: Request):
|
||||
default_copy = constants.default_config.copy()
|
||||
default_processor_config = default_copy["processor"]["conversation"]["openai"] # type: ignore
|
||||
default_openai_config = OpenAIProcessorConfig(
|
||||
api_key="",
|
||||
chat_model=default_processor_config["chat-model"],
|
||||
)
|
||||
user = request.user.object
|
||||
openai_config = ConversationAdapters.get_openai_conversation_config(user)
|
||||
|
||||
if openai_config:
|
||||
current_processor_openai_config = OpenAIProcessorConfig(
|
||||
api_key=openai_config.api_key,
|
||||
chat_model=openai_config.chat_model,
|
||||
)
|
||||
else:
|
||||
current_processor_openai_config = OpenAIProcessorConfig(
|
||||
api_key="",
|
||||
chat_model="gpt-3.5-turbo",
|
||||
)
|
||||
|
||||
current_processor_openai_config = (
|
||||
state.config.processor.conversation.openai
|
||||
if state.config
|
||||
and state.config.processor
|
||||
and state.config.processor.conversation
|
||||
and state.config.processor.conversation.openai
|
||||
else default_openai_config
|
||||
)
|
||||
current_processor_openai_config = json.loads(current_processor_openai_config.json())
|
||||
|
||||
return templates.TemplateResponse(
|
||||
|
|
|
@ -236,6 +236,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
|
|||
"image_score": f"{hit['image_score']:.9f}",
|
||||
"metadata_score": f"{hit['metadata_score']:.9f}",
|
||||
},
|
||||
"corpus_id": hit["corpus_id"],
|
||||
}
|
||||
)
|
||||
]
|
||||
|
|
|
@ -14,10 +14,9 @@ from asgiref.sync import sync_to_async
|
|||
# Internal Packages
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer
|
||||
from khoj.utils.config import TextSearchModel
|
||||
from khoj.utils.models import BaseEncoder
|
||||
from khoj.utils.state import SearchType
|
||||
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, Entry
|
||||
from khoj.utils.rawconfig import SearchResponse, Entry
|
||||
from khoj.utils.jsonl import load_jsonl
|
||||
from khoj.processor.text_to_jsonl import TextEmbeddings
|
||||
from database.adapters import EmbeddingsAdapters
|
||||
|
@ -36,36 +35,6 @@ search_type_to_embeddings_type = {
|
|||
}
|
||||
|
||||
|
||||
def initialize_model(search_config: TextSearchConfig):
|
||||
"Initialize model for semantic search on text"
|
||||
torch.set_num_threads(4)
|
||||
|
||||
# If model directory is configured
|
||||
if search_config.model_directory:
|
||||
# Convert model directory to absolute path
|
||||
search_config.model_directory = resolve_absolute_path(search_config.model_directory)
|
||||
# Create model directory if it doesn't exist
|
||||
search_config.model_directory.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# The bi-encoder encodes all entries to use for semantic search
|
||||
bi_encoder = load_model(
|
||||
model_dir=search_config.model_directory,
|
||||
model_name=search_config.encoder,
|
||||
model_type=search_config.encoder_type or SentenceTransformer,
|
||||
device=f"{state.device}",
|
||||
)
|
||||
|
||||
# The cross-encoder re-ranks the results to improve quality
|
||||
cross_encoder = load_model(
|
||||
model_dir=search_config.model_directory,
|
||||
model_name=search_config.cross_encoder,
|
||||
model_type=CrossEncoder,
|
||||
device=f"{state.device}",
|
||||
)
|
||||
|
||||
return TextSearchModel(bi_encoder, cross_encoder)
|
||||
|
||||
|
||||
def extract_entries(jsonl_file) -> List[Entry]:
|
||||
"Load entries from compressed jsonl"
|
||||
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
|
||||
|
@ -176,6 +145,7 @@ def collate_results(hits, dedupe=True):
|
|||
{
|
||||
"entry": hit.raw,
|
||||
"score": hit.distance,
|
||||
"corpus_id": str(hit.corpus_id),
|
||||
"additional": {
|
||||
"file": hit.file_path,
|
||||
"compiled": hit.compiled,
|
||||
|
@ -185,6 +155,28 @@ def collate_results(hits, dedupe=True):
|
|||
)
|
||||
|
||||
|
||||
def deduplicated_search_responses(hits: List[SearchResponse]):
|
||||
hit_ids = set()
|
||||
for hit in hits:
|
||||
if hit.corpus_id in hit_ids:
|
||||
continue
|
||||
|
||||
else:
|
||||
hit_ids.add(hit.corpus_id)
|
||||
yield SearchResponse.parse_obj(
|
||||
{
|
||||
"entry": hit.entry,
|
||||
"score": hit.score,
|
||||
"corpus_id": hit.corpus_id,
|
||||
"additional": {
|
||||
"file": hit.additional["file"],
|
||||
"compiled": hit.additional["compiled"],
|
||||
"heading": hit.additional["heading"],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def rerank_and_sort_results(hits, query):
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
hits = cross_encoder_score(query, hits)
|
||||
|
|
|
@ -5,8 +5,7 @@ from enum import Enum
|
|||
import logging
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
|
||||
from typing import TYPE_CHECKING, List, Optional, Union, Any
|
||||
from khoj.processor.conversation.gpt4all.utils import download_model
|
||||
|
||||
# External Packages
|
||||
|
@ -19,9 +18,7 @@ logger = logging.getLogger(__name__)
|
|||
# Internal Packages
|
||||
if TYPE_CHECKING:
|
||||
from sentence_transformers import CrossEncoder
|
||||
from khoj.search_filter.base_filter import BaseFilter
|
||||
from khoj.utils.models import BaseEncoder
|
||||
from khoj.utils.rawconfig import ConversationProcessorConfig, Entry, OpenAIProcessorConfig
|
||||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
|
@ -79,31 +76,15 @@ class GPT4AllProcessorConfig:
|
|||
loaded_model: Union[Any, None] = None
|
||||
|
||||
|
||||
class ConversationProcessorConfigModel:
|
||||
class GPT4AllProcessorModel:
|
||||
def __init__(
|
||||
self,
|
||||
conversation_config: ConversationProcessorConfig,
|
||||
chat_model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
|
||||
):
|
||||
self.openai_model = conversation_config.openai
|
||||
self.gpt4all_model = GPT4AllProcessorConfig()
|
||||
self.offline_chat = conversation_config.offline_chat or OfflineChatProcessorConfig()
|
||||
self.max_prompt_size = conversation_config.max_prompt_size
|
||||
self.tokenizer = conversation_config.tokenizer
|
||||
self.conversation_logfile = Path(conversation_config.conversation_logfile)
|
||||
self.chat_session: List[str] = []
|
||||
self.meta_log: dict = {}
|
||||
|
||||
if self.offline_chat.enable_offline_chat:
|
||||
try:
|
||||
self.gpt4all_model.loaded_model = download_model(self.offline_chat.chat_model)
|
||||
except Exception as e:
|
||||
self.offline_chat.enable_offline_chat = False
|
||||
self.gpt4all_model.loaded_model = None
|
||||
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
|
||||
else:
|
||||
self.gpt4all_model.loaded_model = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessorConfigModel:
|
||||
conversation: Union[ConversationProcessorConfigModel, None] = None
|
||||
self.chat_model = chat_model
|
||||
self.loaded_model = None
|
||||
try:
|
||||
self.loaded_model = download_model(self.chat_model)
|
||||
except ValueError as e:
|
||||
self.loaded_model = None
|
||||
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
|
||||
|
|
|
@ -8,136 +8,14 @@ telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
|
|||
content_directory = "~/.khoj/content/"
|
||||
|
||||
empty_config = {
|
||||
"content-type": {
|
||||
"org": {
|
||||
"input-files": None,
|
||||
"input-filter": None,
|
||||
"compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz",
|
||||
"embeddings-file": "~/.khoj/content/org/org_embeddings.pt",
|
||||
"index-heading-entries": False,
|
||||
},
|
||||
"markdown": {
|
||||
"input-files": None,
|
||||
"input-filter": None,
|
||||
"compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz",
|
||||
"embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt",
|
||||
},
|
||||
"pdf": {
|
||||
"input-files": None,
|
||||
"input-filter": None,
|
||||
"compressed-jsonl": "~/.khoj/content/pdf/pdf.jsonl.gz",
|
||||
"embeddings-file": "~/.khoj/content/pdf/pdf_embeddings.pt",
|
||||
},
|
||||
"plaintext": {
|
||||
"input-files": None,
|
||||
"input-filter": None,
|
||||
"compressed-jsonl": "~/.khoj/content/plaintext/plaintext.jsonl.gz",
|
||||
"embeddings-file": "~/.khoj/content/plaintext/plaintext_embeddings.pt",
|
||||
},
|
||||
},
|
||||
"search-type": {
|
||||
"symmetric": {
|
||||
"encoder": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
"model_directory": "~/.khoj/search/symmetric/",
|
||||
},
|
||||
"asymmetric": {
|
||||
"encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
|
||||
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
"model_directory": "~/.khoj/search/asymmetric/",
|
||||
},
|
||||
"image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
|
||||
},
|
||||
"processor": {
|
||||
"conversation": {
|
||||
"openai": {
|
||||
"api-key": None,
|
||||
"chat-model": "gpt-3.5-turbo",
|
||||
},
|
||||
"offline-chat": {
|
||||
"enable-offline-chat": False,
|
||||
"chat-model": "llama-2-7b-chat.ggmlv3.q4_0.bin",
|
||||
},
|
||||
"tokenizer": None,
|
||||
"max-prompt-size": None,
|
||||
"conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# default app config to use
|
||||
default_config = {
|
||||
"content-type": {
|
||||
"org": {
|
||||
"input-files": None,
|
||||
"input-filter": None,
|
||||
"compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz",
|
||||
"embeddings-file": "~/.khoj/content/org/org_embeddings.pt",
|
||||
"index-heading-entries": False,
|
||||
},
|
||||
"markdown": {
|
||||
"input-files": None,
|
||||
"input-filter": None,
|
||||
"compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz",
|
||||
"embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt",
|
||||
},
|
||||
"pdf": {
|
||||
"input-files": None,
|
||||
"input-filter": None,
|
||||
"compressed-jsonl": "~/.khoj/content/pdf/pdf.jsonl.gz",
|
||||
"embeddings-file": "~/.khoj/content/pdf/pdf_embeddings.pt",
|
||||
},
|
||||
"image": {
|
||||
"input-directories": None,
|
||||
"input-filter": None,
|
||||
"embeddings-file": "~/.khoj/content/image/image_embeddings.pt",
|
||||
"batch-size": 50,
|
||||
"use-xmp-metadata": False,
|
||||
},
|
||||
"github": {
|
||||
"pat-token": None,
|
||||
"repos": [],
|
||||
"compressed-jsonl": "~/.khoj/content/github/github.jsonl.gz",
|
||||
"embeddings-file": "~/.khoj/content/github/github_embeddings.pt",
|
||||
},
|
||||
"notion": {
|
||||
"token": None,
|
||||
"compressed-jsonl": "~/.khoj/content/notion/notion.jsonl.gz",
|
||||
"embeddings-file": "~/.khoj/content/notion/notion_embeddings.pt",
|
||||
},
|
||||
"plaintext": {
|
||||
"input-files": None,
|
||||
"input-filter": None,
|
||||
"compressed-jsonl": "~/.khoj/content/plaintext/plaintext.jsonl.gz",
|
||||
"embeddings-file": "~/.khoj/content/plaintext/plaintext_embeddings.pt",
|
||||
},
|
||||
},
|
||||
"search-type": {
|
||||
"symmetric": {
|
||||
"encoder": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
"model_directory": "~/.khoj/search/symmetric/",
|
||||
},
|
||||
"asymmetric": {
|
||||
"encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
|
||||
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
"model_directory": "~/.khoj/search/asymmetric/",
|
||||
},
|
||||
"image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
|
||||
},
|
||||
"processor": {
|
||||
"conversation": {
|
||||
"openai": {
|
||||
"api-key": None,
|
||||
"chat-model": "gpt-3.5-turbo",
|
||||
},
|
||||
"offline-chat": {
|
||||
"enable-offline-chat": False,
|
||||
"chat-model": "llama-2-7b-chat.ggmlv3.q4_0.bin",
|
||||
},
|
||||
"tokenizer": None,
|
||||
"max-prompt-size": None,
|
||||
"conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ from time import perf_counter
|
|||
import torch
|
||||
from typing import Optional, Union, TYPE_CHECKING
|
||||
import uuid
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils import constants
|
||||
|
@ -29,6 +30,28 @@ if TYPE_CHECKING:
|
|||
from khoj.utils.rawconfig import AppConfig
|
||||
|
||||
|
||||
class AsyncIteratorWrapper:
|
||||
def __init__(self, obj):
|
||||
self._it = iter(obj)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
value = await self.next_async()
|
||||
except StopAsyncIteration:
|
||||
return
|
||||
return value
|
||||
|
||||
@sync_to_async
|
||||
def next_async(self):
|
||||
try:
|
||||
return next(self._it)
|
||||
except StopIteration:
|
||||
raise StopAsyncIteration
|
||||
|
||||
|
||||
def is_none_or_empty(item):
|
||||
return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""
|
||||
|
||||
|
|
|
@ -67,13 +67,6 @@ class ContentConfig(ConfigBase):
|
|||
notion: Optional[NotionContentConfig]
|
||||
|
||||
|
||||
class TextSearchConfig(ConfigBase):
|
||||
encoder: str
|
||||
cross_encoder: str
|
||||
encoder_type: Optional[str]
|
||||
model_directory: Optional[Path]
|
||||
|
||||
|
||||
class ImageSearchConfig(ConfigBase):
|
||||
encoder: str
|
||||
encoder_type: Optional[str]
|
||||
|
@ -81,8 +74,6 @@ class ImageSearchConfig(ConfigBase):
|
|||
|
||||
|
||||
class SearchConfig(ConfigBase):
|
||||
asymmetric: Optional[TextSearchConfig]
|
||||
symmetric: Optional[TextSearchConfig]
|
||||
image: Optional[ImageSearchConfig]
|
||||
|
||||
|
||||
|
@ -97,11 +88,10 @@ class OfflineChatProcessorConfig(ConfigBase):
|
|||
|
||||
|
||||
class ConversationProcessorConfig(ConfigBase):
|
||||
conversation_logfile: Path
|
||||
openai: Optional[OpenAIProcessorConfig]
|
||||
offline_chat: Optional[OfflineChatProcessorConfig]
|
||||
max_prompt_size: Optional[int]
|
||||
tokenizer: Optional[str]
|
||||
openai: Optional[OpenAIProcessorConfig] = None
|
||||
offline_chat: Optional[OfflineChatProcessorConfig] = None
|
||||
max_prompt_size: Optional[int] = None
|
||||
tokenizer: Optional[str] = None
|
||||
|
||||
|
||||
class ProcessorConfig(ConfigBase):
|
||||
|
@ -125,6 +115,7 @@ class SearchResponse(ConfigBase):
|
|||
score: float
|
||||
cross_score: Optional[float]
|
||||
additional: Optional[dict]
|
||||
corpus_id: str
|
||||
|
||||
|
||||
class Entry:
|
||||
|
|
|
@ -10,7 +10,7 @@ from pathlib import Path
|
|||
|
||||
# Internal Packages
|
||||
from khoj.utils import config as utils_config
|
||||
from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
|
||||
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
|
||||
from khoj.utils.helpers import LRU
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
|
||||
|
@ -21,7 +21,7 @@ search_models = SearchModels()
|
|||
embeddings_model = EmbeddingsModel()
|
||||
cross_encoder_model = CrossEncoderModel()
|
||||
content_index = ContentIndex()
|
||||
processor_config = ProcessorConfigModel()
|
||||
gpt4all_processor_config: GPT4AllProcessorModel = None
|
||||
config_file: Path = None
|
||||
verbose: int = 0
|
||||
host: str = None
|
||||
|
|
|
@ -5,7 +5,6 @@ from pathlib import Path
|
|||
import pytest
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi import FastAPI
|
||||
import factory
|
||||
import os
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
@ -13,7 +12,7 @@ app = FastAPI()
|
|||
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware
|
||||
from khoj.configure import configure_routes, configure_search_types, configure_middleware
|
||||
from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.utils.config import SearchModels
|
||||
|
@ -21,13 +20,8 @@ from khoj.utils.constants import web_directory
|
|||
from khoj.utils.helpers import resolve_absolute_path
|
||||
from khoj.utils.rawconfig import (
|
||||
ContentConfig,
|
||||
ConversationProcessorConfig,
|
||||
OfflineChatProcessorConfig,
|
||||
OpenAIProcessorConfig,
|
||||
ProcessorConfig,
|
||||
ImageContentConfig,
|
||||
SearchConfig,
|
||||
TextSearchConfig,
|
||||
ImageSearchConfig,
|
||||
)
|
||||
from khoj.utils import state, fs_syncer
|
||||
|
@ -42,42 +36,25 @@ from database.models import (
|
|||
GithubRepoConfig,
|
||||
)
|
||||
|
||||
from tests.helpers import (
|
||||
UserFactory,
|
||||
ConversationProcessorConfigFactory,
|
||||
OpenAIProcessorConversationConfigFactory,
|
||||
OfflineChatProcessorConversationConfigFactory,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_db_access_for_all_tests(db):
|
||||
pass
|
||||
|
||||
|
||||
class UserFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = KhojUser
|
||||
|
||||
username = factory.Faker("name")
|
||||
email = factory.Faker("email")
|
||||
password = factory.Faker("password")
|
||||
uuid = factory.Faker("uuid4")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def search_config() -> SearchConfig:
|
||||
model_dir = resolve_absolute_path("~/.khoj/search")
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
search_config = SearchConfig()
|
||||
|
||||
search_config.symmetric = TextSearchConfig(
|
||||
encoder="sentence-transformers/all-MiniLM-L6-v2",
|
||||
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
model_directory=model_dir / "symmetric/",
|
||||
encoder_type=None,
|
||||
)
|
||||
|
||||
search_config.asymmetric = TextSearchConfig(
|
||||
encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
|
||||
cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
model_directory=model_dir / "asymmetric/",
|
||||
encoder_type=None,
|
||||
)
|
||||
|
||||
search_config.image = ImageSearchConfig(
|
||||
encoder="sentence-transformers/clip-ViT-B-32",
|
||||
model_directory=model_dir / "image/",
|
||||
|
@ -177,55 +154,48 @@ def md_content_config():
|
|||
return markdown_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def processor_config(tmp_path_factory):
|
||||
openai_api_key = os.getenv("OPENAI_API_KEY")
|
||||
processor_dir = tmp_path_factory.mktemp("processor")
|
||||
|
||||
# The conversation processor is the only configured processor
|
||||
# It needs an OpenAI API key to work.
|
||||
if not openai_api_key:
|
||||
return
|
||||
|
||||
# Setup conversation processor, if OpenAI API key is set
|
||||
processor_config = ProcessorConfig()
|
||||
processor_config.conversation = ConversationProcessorConfig(
|
||||
openai=OpenAIProcessorConfig(api_key=openai_api_key),
|
||||
conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
|
||||
)
|
||||
|
||||
return processor_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def processor_config_offline_chat(tmp_path_factory):
|
||||
processor_dir = tmp_path_factory.mktemp("processor")
|
||||
|
||||
# Setup conversation processor
|
||||
processor_config = ProcessorConfig()
|
||||
offline_chat = OfflineChatProcessorConfig(enable_offline_chat=True)
|
||||
processor_config.conversation = ConversationProcessorConfig(
|
||||
offline_chat=offline_chat,
|
||||
conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
|
||||
)
|
||||
|
||||
return processor_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
|
||||
@pytest.fixture(scope="function")
|
||||
def chat_client(search_config: SearchConfig, default_user2: KhojUser):
|
||||
# Initialize app state
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
|
||||
LocalMarkdownConfig.objects.create(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/markdown/*.markdown"],
|
||||
user=default_user2,
|
||||
)
|
||||
|
||||
# Index Markdown Content for Search
|
||||
all_files = fs_syncer.collect_files()
|
||||
all_files = fs_syncer.collect_files(user=default_user2)
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, all_files, state.search_models
|
||||
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
|
||||
)
|
||||
|
||||
# Initialize Processor from Config
|
||||
state.processor_config = configure_processor(processor_config)
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
OpenAIProcessorConversationConfigFactory(user=default_user2)
|
||||
|
||||
state.anonymous_mode = True
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
configure_routes(app)
|
||||
configure_middleware(app)
|
||||
app.mount("/static", StaticFiles(directory=web_directory), name="static")
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
|
||||
# Initialize app state
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
|
||||
# Initialize Processor from Config
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
OpenAIProcessorConversationConfigFactory(user=default_user2)
|
||||
|
||||
state.anonymous_mode = True
|
||||
|
||||
app = FastAPI()
|
||||
|
@ -249,7 +219,6 @@ def fastapi_app():
|
|||
def client(
|
||||
content_config: ContentConfig,
|
||||
search_config: SearchConfig,
|
||||
processor_config: ProcessorConfig,
|
||||
default_user: KhojUser,
|
||||
):
|
||||
state.config.content_type = content_config
|
||||
|
@ -274,7 +243,7 @@ def client(
|
|||
user=default_user,
|
||||
)
|
||||
|
||||
state.processor_config = configure_processor(processor_config)
|
||||
ConversationProcessorConfigFactory(user=default_user)
|
||||
state.anonymous_mode = True
|
||||
|
||||
configure_routes(app)
|
||||
|
@ -286,25 +255,32 @@ def client(
|
|||
@pytest.fixture(scope="function")
|
||||
def client_offline_chat(
|
||||
search_config: SearchConfig,
|
||||
processor_config_offline_chat: ProcessorConfig,
|
||||
content_config: ContentConfig,
|
||||
md_content_config,
|
||||
default_user2: KhojUser,
|
||||
):
|
||||
# Initialize app state
|
||||
state.config.content_type = md_content_config
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
|
||||
LocalMarkdownConfig.objects.create(
|
||||
input_files=None,
|
||||
input_filter=["tests/data/markdown/*.markdown"],
|
||||
user=default_user2,
|
||||
)
|
||||
|
||||
# Index Markdown Content for Search
|
||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
|
||||
all_files = fs_syncer.collect_files(state.config.content_type)
|
||||
state.content_index = configure_content(
|
||||
state.content_index, state.config.content_type, all_files, state.search_models
|
||||
all_files = fs_syncer.collect_files(user=default_user2)
|
||||
configure_content(
|
||||
state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
|
||||
)
|
||||
|
||||
# Initialize Processor from Config
|
||||
state.processor_config = configure_processor(processor_config_offline_chat)
|
||||
ConversationProcessorConfigFactory(user=default_user2)
|
||||
OfflineChatProcessorConversationConfigFactory(user=default_user2)
|
||||
|
||||
state.anonymous_mode = True
|
||||
|
||||
configure_routes(app)
|
||||
|
|
51
tests/helpers.py
Normal file
51
tests/helpers.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
import factory
|
||||
import os
|
||||
|
||||
from database.models import (
|
||||
KhojUser,
|
||||
ConversationProcessorConfig,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
OpenAIProcessorConversationConfig,
|
||||
Conversation,
|
||||
)
|
||||
|
||||
|
||||
class UserFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = KhojUser
|
||||
|
||||
username = factory.Faker("name")
|
||||
email = factory.Faker("email")
|
||||
password = factory.Faker("password")
|
||||
uuid = factory.Faker("uuid4")
|
||||
|
||||
|
||||
class ConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = ConversationProcessorConfig
|
||||
|
||||
max_prompt_size = 2000
|
||||
tokenizer = None
|
||||
|
||||
|
||||
class OfflineChatProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = OfflineChatProcessorConversationConfig
|
||||
|
||||
enable_offline_chat = True
|
||||
chat_model = "llama-2-7b-chat.ggmlv3.q4_0.bin"
|
||||
|
||||
|
||||
class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = OpenAIProcessorConversationConfig
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
chat_model = "gpt-3.5-turbo"
|
||||
|
||||
|
||||
class ConversationFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = Conversation
|
||||
|
||||
user = factory.SubFactory(UserFactory)
|
|
@ -119,7 +119,12 @@ def test_get_configured_types_via_api(client, sample_org_data):
|
|||
def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser):
|
||||
# Arrange
|
||||
text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/config/types")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
assert response.json() == ["all", "org", "image"]
|
||||
|
||||
|
||||
|
|
|
@ -9,8 +9,7 @@ from faker import Faker
|
|||
# Internal Packages
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.utils import state
|
||||
|
||||
from tests.helpers import ConversationFactory
|
||||
|
||||
SKIP_TESTS = True
|
||||
pytestmark = pytest.mark.skipif(
|
||||
|
@ -23,7 +22,7 @@ fake = Faker()
|
|||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def populate_chat_history(message_list):
|
||||
def populate_chat_history(message_list, user):
|
||||
# Generate conversation logs
|
||||
conversation_log = {"chat": []}
|
||||
for user_message, llm_message, context in message_list:
|
||||
|
@ -33,14 +32,15 @@ def populate_chat_history(message_list):
|
|||
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
|
||||
)
|
||||
|
||||
# Update Conversation Metadata Logs in Application State
|
||||
state.processor_config.conversation.meta_log = conversation_log
|
||||
# Update Conversation Metadata Logs in Database
|
||||
ConversationFactory(user=user, conversation_log=conversation_log)
|
||||
|
||||
|
||||
# Tests
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_chat):
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
|
||||
|
@ -56,13 +56,14 @@ def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_c
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_from_chat_history(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
|
@ -78,7 +79,8 @@ def test_answer_from_chat_history(client_offline_chat):
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_currently_retrieved_content(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_from_currently_retrieved_content(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
|
@ -88,7 +90,7 @@ def test_answer_from_currently_retrieved_content(client_offline_chat):
|
|||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was Xi Li born?"')
|
||||
|
@ -101,7 +103,8 @@ def test_answer_from_currently_retrieved_content(client_offline_chat):
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
|
@ -111,7 +114,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin
|
|||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
|
||||
|
@ -130,13 +133,14 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin
|
|||
reason="Chat director not capable of answering this question yet because it requires extract_questions",
|
||||
)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
|
||||
|
@ -154,14 +158,15 @@ def test_answer_from_chat_history_and_currently_retrieved_content(client_offline
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
|
||||
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"&stream=true')
|
||||
|
@ -177,11 +182,12 @@ def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat):
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_using_general_command(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_using_general_command(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
query = urllib.parse.quote("/general Where was Xi Li born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
|
||||
|
@ -194,11 +200,12 @@ def test_answer_using_general_command(client_offline_chat):
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_from_retrieved_content_using_notes_command(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
query = urllib.parse.quote("/notes Where was Xi Li born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
|
||||
|
@ -211,12 +218,13 @@ def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_using_file_filter(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_using_file_filter(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"')
|
||||
answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
|
||||
message_list = []
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
no_answer_response = client_offline_chat.get(f"/api/chat?q={no_answer_query}&stream=true").content.decode("utf-8")
|
||||
|
@ -229,11 +237,12 @@ def test_answer_using_file_filter(client_offline_chat):
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_not_known_using_notes_command(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_not_known_using_notes_command(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
query = urllib.parse.quote("/notes Where was Testatron born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
|
||||
|
@ -247,6 +256,7 @@ def test_answer_not_known_using_notes_command(client_offline_chat):
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@freeze_time("2023-04-01")
|
||||
def test_answer_requires_current_date_awareness(client_offline_chat):
|
||||
"Chat actor should be able to answer questions relative to current date using provided notes"
|
||||
|
@ -265,6 +275,7 @@ def test_answer_requires_current_date_awareness(client_offline_chat):
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@freeze_time("2023-04-01")
|
||||
def test_answer_requires_date_aware_aggregation_across_provided_notes(client_offline_chat):
|
||||
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
|
||||
|
@ -280,14 +291,15 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(client_off
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(
|
||||
|
@ -307,7 +319,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(client
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.")
|
||||
@pytest.mark.chatquality
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat, default_user2):
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
@ -328,14 +341,15 @@ def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
|
@ -350,11 +364,12 @@ def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat):
|
|||
|
||||
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_chat_history_very_long(client_offline_chat):
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_chat_history_very_long(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)]
|
||||
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
|
@ -368,6 +383,7 @@ def test_answer_chat_history_very_long(client_offline_chat):
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_requires_multiple_independent_searches(client_offline_chat):
|
||||
"Chat director should be able to answer by doing multiple independent searches for required information"
|
||||
# Act
|
||||
|
|
|
@ -9,8 +9,8 @@ from khoj.processor.conversation import prompts
|
|||
|
||||
# Internal Packages
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.utils import state
|
||||
|
||||
from tests.helpers import ConversationFactory
|
||||
from database.models import KhojUser
|
||||
|
||||
# Initialize variables for tests
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
@ -23,7 +23,7 @@ if api_key is None:
|
|||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def populate_chat_history(message_list):
|
||||
def populate_chat_history(message_list, user=None):
|
||||
# Generate conversation logs
|
||||
conversation_log = {"chat": []}
|
||||
for user_message, gpt_message, context in message_list:
|
||||
|
@ -33,13 +33,14 @@ def populate_chat_history(message_list):
|
|||
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
|
||||
)
|
||||
|
||||
# Update Conversation Metadata Logs in Application State
|
||||
state.processor_config.conversation.meta_log = conversation_log
|
||||
# Update Conversation Metadata Logs in Database
|
||||
ConversationFactory(user=user, conversation_log=conversation_log)
|
||||
|
||||
|
||||
# Tests
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
|
||||
|
@ -54,14 +55,15 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history(chat_client):
|
||||
def test_answer_from_chat_history(chat_client, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
|
@ -76,8 +78,9 @@ def test_answer_from_chat_history(chat_client):
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_currently_retrieved_content(chat_client):
|
||||
def test_answer_from_currently_retrieved_content(chat_client, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
|
@ -87,7 +90,7 @@ def test_answer_from_currently_retrieved_content(chat_client):
|
|||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')
|
||||
|
@ -99,8 +102,9 @@ def test_answer_from_currently_retrieved_content(chat_client):
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
|
||||
def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_no_background, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
|
@ -110,10 +114,10 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
|
|||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was I born?"')
|
||||
response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
|
@ -125,14 +129,15 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history_and_currently_retrieved_content(chat_client):
|
||||
def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was I born?"')
|
||||
|
@ -148,15 +153,16 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client):
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
|
||||
def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_user2: KhojUser):
|
||||
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true')
|
||||
|
@ -171,12 +177,13 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_using_general_command(chat_client):
|
||||
def test_answer_using_general_command(chat_client, default_user2: KhojUser):
|
||||
# Arrange
|
||||
query = urllib.parse.quote("/general Where was Xi Li born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f"/api/chat?q={query}&stream=true")
|
||||
|
@ -188,12 +195,13 @@ def test_answer_using_general_command(chat_client):
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_retrieved_content_using_notes_command(chat_client):
|
||||
def test_answer_from_retrieved_content_using_notes_command(chat_client, default_user2: KhojUser):
|
||||
# Arrange
|
||||
query = urllib.parse.quote("/notes Where was Xi Li born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f"/api/chat?q={query}&stream=true")
|
||||
|
@ -205,15 +213,16 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client):
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_not_known_using_notes_command(chat_client):
|
||||
def test_answer_not_known_using_notes_command(chat_client_no_background, default_user2: KhojUser):
|
||||
# Arrange
|
||||
query = urllib.parse.quote("/notes Where was Testatron born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f"/api/chat?q={query}&stream=true")
|
||||
response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
|
@ -223,6 +232,7 @@ def test_answer_not_known_using_notes_command(chat_client):
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("2023-04-01")
|
||||
def test_answer_requires_current_date_awareness(chat_client):
|
||||
|
@ -240,11 +250,13 @@ def test_answer_requires_current_date_awareness(chat_client):
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("2023-04-01")
|
||||
def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client):
|
||||
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
|
||||
# Act
|
||||
|
||||
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
|
@ -254,15 +266,16 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_clien
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_client):
|
||||
def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_client, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(
|
||||
|
@ -280,10 +293,12 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background):
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
|
||||
|
||||
response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
|
@ -301,15 +316,16 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
|
|||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_in_chat_history_beyond_lookback_window(chat_client):
|
||||
def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
populate_chat_history(message_list)
|
||||
populate_chat_history(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
|
@ -324,6 +340,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_requires_multiple_independent_searches(chat_client):
|
||||
"Chat director should be able to answer by doing multiple independent searches for required information"
|
||||
|
@ -340,10 +357,12 @@ def test_answer_requires_multiple_independent_searches(chat_client):
|
|||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_using_file_filter(chat_client):
|
||||
"Chat should be able to use search filters in the query"
|
||||
# Act
|
||||
query = urllib.parse.quote('Is Xi older than Namita? file:"Namita.markdown" file:"Xi Li.markdown"')
|
||||
|
||||
response = chat_client.get(f"/api/chat?q={query}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
|
|
|
@ -13,12 +13,11 @@ from khoj.search_type import text_search
|
|||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
from khoj.processor.github.github_to_jsonl import GithubToJsonl
|
||||
from khoj.utils.config import SearchModels
|
||||
from khoj.utils.fs_syncer import get_org_files, collect_files
|
||||
from khoj.utils.fs_syncer import collect_files, get_org_files
|
||||
from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||
|
||||
|
||||
# Test
|
||||
|
|
Loading…
Reference in a new issue