Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-27 17:35:07 +01:00
[Multi-User Part 3]: Separate chat sessions based on authenticated users (#511)
- Add a data model which allows us to store Conversations with users. This does a minimal lift over the current setup, where the underlying data is stored in a JSON file, and maintains parity with that configuration.
- There does _seem_ to be some regression in chat quality, which is most likely attributable to search results.

This will help us with #275. It should become much easier to maintain multiple Conversations in a given table in the backend now. We will have to do some thinking on the UI.
Parent: a8a82d274a
Commit: 4b6ec248a6
24 changed files with 719 additions and 626 deletions
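The heart of the change is the new Conversation model and ConversationAdapters class in the diff below: each chat log now hangs off a KhojUser foreign key instead of a shared JSON file on disk. Here is a minimal sketch of how the new adapter API is meant to be used, assuming a configured Django environment with this repo's database app; the append_chat_turn helper and the exact message dict shape are illustrative assumptions, while get_conversation_by_user and save_conversation come straight from the diff:

    # Illustrative sketch only; assumes Django is set up with this repo's database app.
    from database.adapters import ConversationAdapters
    from database.models import KhojUser

    def append_chat_turn(user: KhojUser, user_message: str, khoj_response: str) -> None:
        # Fetch this user's Conversation row, creating it on first use.
        conversation = ConversationAdapters.get_conversation_by_user(user)
        log = conversation.conversation_log or {}
        # The log keeps the old JSON-file layout: a "chat" list of message dicts (shape assumed here).
        log.setdefault("chat", []).append({"by": "you", "message": user_message})
        log["chat"].append({"by": "khoj", "message": khoj_response})
        # Persist the updated log back onto the user's row.
        ConversationAdapters.save_conversation(user, log)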
@@ -2,3 +2,5 @@
 DJANGO_SETTINGS_MODULE = app.settings
 pythonpath = . src
 testpaths = tests
+markers =
+    chatquality: marks tests as chatquality (deselect with '-m "not chatquality"')
@@ -1,5 +1,4 @@
 from typing import Type, TypeVar, List
-import uuid
 from datetime import date
 
 from django.db import models
@@ -21,6 +20,13 @@ from database.models import (
     GithubConfig,
     Embeddings,
     GithubRepoConfig,
+    Conversation,
+    ConversationProcessorConfig,
+    OpenAIProcessorConversationConfig,
+    OfflineChatProcessorConversationConfig,
+)
+from khoj.utils.rawconfig import (
+    ConversationProcessorConfig as UserConversationProcessorConfig,
 )
 from khoj.search_filter.word_filter import WordFilter
 from khoj.search_filter.file_filter import FileFilter
@@ -54,18 +60,17 @@ async def get_or_create_user(token: dict) -> KhojUser:
 
 
 async def create_google_user(token: dict) -> KhojUser:
-    user_info = token.get("userinfo")
-    user = await KhojUser.objects.acreate(username=user_info.get("email"), email=user_info.get("email"))
+    user = await KhojUser.objects.acreate(username=token.get("email"), email=token.get("email"))
     await user.asave()
     await GoogleUser.objects.acreate(
-        sub=user_info.get("sub"),
-        azp=user_info.get("azp"),
-        email=user_info.get("email"),
-        name=user_info.get("name"),
-        given_name=user_info.get("given_name"),
-        family_name=user_info.get("family_name"),
-        picture=user_info.get("picture"),
-        locale=user_info.get("locale"),
+        sub=token.get("sub"),
+        azp=token.get("azp"),
+        email=token.get("email"),
+        name=token.get("name"),
+        given_name=token.get("given_name"),
+        family_name=token.get("family_name"),
+        picture=token.get("picture"),
+        locale=token.get("locale"),
         user=user,
     )
 
@@ -137,6 +142,124 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
     return config
 
 
+class ConversationAdapters:
+    @staticmethod
+    def get_conversation_by_user(user: KhojUser):
+        conversation = Conversation.objects.filter(user=user)
+        if conversation.exists():
+            return conversation.first()
+        return Conversation.objects.create(user=user)
+
+    @staticmethod
+    async def aget_conversation_by_user(user: KhojUser):
+        conversation = Conversation.objects.filter(user=user)
+        if await conversation.aexists():
+            return await conversation.afirst()
+        return await Conversation.objects.acreate(user=user)
+
+    @staticmethod
+    def has_any_conversation_config(user: KhojUser):
+        return ConversationProcessorConfig.objects.filter(user=user).exists()
+
+    @staticmethod
+    def get_openai_conversation_config(user: KhojUser):
+        return OpenAIProcessorConversationConfig.objects.filter(user=user).first()
+
+    @staticmethod
+    def get_offline_chat_conversation_config(user: KhojUser):
+        return OfflineChatProcessorConversationConfig.objects.filter(user=user).first()
+
+    @staticmethod
+    def has_valid_offline_conversation_config(user: KhojUser):
+        return OfflineChatProcessorConversationConfig.objects.filter(user=user, enable_offline_chat=True).exists()
+
+    @staticmethod
+    def has_valid_openai_conversation_config(user: KhojUser):
+        return OpenAIProcessorConversationConfig.objects.filter(user=user).exists()
+
+    @staticmethod
+    def get_conversation_config(user: KhojUser):
+        return ConversationProcessorConfig.objects.filter(user=user).first()
+
+    @staticmethod
+    def save_conversation(user: KhojUser, conversation_log: dict):
+        conversation = Conversation.objects.filter(user=user)
+        if conversation.exists():
+            conversation.update(conversation_log=conversation_log)
+        else:
+            Conversation.objects.create(user=user, conversation_log=conversation_log)
+
+    @staticmethod
+    def set_conversation_processor_config(user: KhojUser, new_config: UserConversationProcessorConfig):
+        conversation_config, _ = ConversationProcessorConfig.objects.get_or_create(user=user)
+        conversation_config.max_prompt_size = new_config.max_prompt_size
+        conversation_config.tokenizer = new_config.tokenizer
+        conversation_config.save()
+
+        if new_config.openai:
+            default_values = {
+                "api_key": new_config.openai.api_key,
+            }
+            if new_config.openai.chat_model:
+                default_values["chat_model"] = new_config.openai.chat_model
+
+            OpenAIProcessorConversationConfig.objects.update_or_create(user=user, defaults=default_values)
+
+        if new_config.offline_chat:
+            default_values = {
+                "enable_offline_chat": str(new_config.offline_chat.enable_offline_chat),
+            }
+
+            if new_config.offline_chat.chat_model:
+                default_values["chat_model"] = new_config.offline_chat.chat_model
+
+            OfflineChatProcessorConversationConfig.objects.update_or_create(user=user, defaults=default_values)
+
+    @staticmethod
+    def get_enabled_conversation_settings(user: KhojUser):
+        openai_config = ConversationAdapters.get_openai_conversation_config(user)
+        offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user)
+
+        return {
+            "openai": True if openai_config is not None else False,
+            "offline_chat": True
+            if (offline_chat_config is not None and offline_chat_config.enable_offline_chat)
+            else False,
+        }
+
+    @staticmethod
+    def clear_conversation_config(user: KhojUser):
+        ConversationProcessorConfig.objects.filter(user=user).delete()
+        ConversationAdapters.clear_openai_conversation_config(user)
+        ConversationAdapters.clear_offline_chat_conversation_config(user)
+
+    @staticmethod
+    def clear_openai_conversation_config(user: KhojUser):
+        OpenAIProcessorConversationConfig.objects.filter(user=user).delete()
+
+    @staticmethod
+    def clear_offline_chat_conversation_config(user: KhojUser):
+        OfflineChatProcessorConversationConfig.objects.filter(user=user).delete()
+
+    @staticmethod
+    async def has_offline_chat(user: KhojUser):
+        return await OfflineChatProcessorConversationConfig.objects.filter(
+            user=user, enable_offline_chat=True
+        ).aexists()
+
+    @staticmethod
+    async def get_offline_chat(user: KhojUser):
+        return await OfflineChatProcessorConversationConfig.objects.filter(user=user).afirst()
+
+    @staticmethod
+    async def has_openai_chat(user: KhojUser):
+        return await OpenAIProcessorConversationConfig.objects.filter(user=user).aexists()
+
+    @staticmethod
+    async def get_openai_chat(user: KhojUser):
+        return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst()
+
+
 class EmbeddingsAdapters:
     word_filer = WordFilter()
     file_filter = FileFilter()
@@ -0,0 +1,81 @@
+# Generated by Django 4.2.5 on 2023-10-18 05:31
+
+from django.conf import settings
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0006_embeddingsdates"),
+    ]
+
+    operations = [
+        migrations.RemoveField(
+            model_name="conversationprocessorconfig",
+            name="conversation",
+        ),
+        migrations.RemoveField(
+            model_name="conversationprocessorconfig",
+            name="enable_offline_chat",
+        ),
+        migrations.AddField(
+            model_name="conversationprocessorconfig",
+            name="max_prompt_size",
+            field=models.IntegerField(blank=True, default=None, null=True),
+        ),
+        migrations.AddField(
+            model_name="conversationprocessorconfig",
+            name="tokenizer",
+            field=models.CharField(blank=True, default=None, max_length=200, null=True),
+        ),
+        migrations.AddField(
+            model_name="conversationprocessorconfig",
+            name="user",
+            field=models.ForeignKey(
+                default=1, on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
+            ),
+            preserve_default=False,
+        ),
+        migrations.CreateModel(
+            name="OpenAIProcessorConversationConfig",
+            fields=[
+                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+                ("created_at", models.DateTimeField(auto_now_add=True)),
+                ("updated_at", models.DateTimeField(auto_now=True)),
+                ("api_key", models.CharField(max_length=200)),
+                ("chat_model", models.CharField(max_length=200)),
+                ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
+            ],
+            options={
+                "abstract": False,
+            },
+        ),
+        migrations.CreateModel(
+            name="OfflineChatProcessorConversationConfig",
+            fields=[
+                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+                ("created_at", models.DateTimeField(auto_now_add=True)),
+                ("updated_at", models.DateTimeField(auto_now=True)),
+                ("enable_offline_chat", models.BooleanField(default=False)),
+                ("chat_model", models.CharField(default="llama-2-7b-chat.ggmlv3.q4_0.bin", max_length=200)),
+                ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
+            ],
+            options={
+                "abstract": False,
+            },
+        ),
+        migrations.CreateModel(
+            name="Conversation",
+            fields=[
+                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+                ("created_at", models.DateTimeField(auto_now_add=True)),
+                ("updated_at", models.DateTimeField(auto_now=True)),
+                ("conversation_log", models.JSONField()),
+                ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
+            ],
+            options={
+                "abstract": False,
+            },
+        ),
+    ]
@@ -0,0 +1,17 @@
+# Generated by Django 4.2.5 on 2023-10-18 16:46
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0007_remove_conversationprocessorconfig_conversation_and_more"),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name="conversation",
+            name="conversation_log",
+            field=models.JSONField(default=dict),
+        ),
+    ]
@@ -82,9 +82,27 @@ class LocalPlaintextConfig(BaseModel):
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
 
 
-class ConversationProcessorConfig(BaseModel):
-    conversation = models.JSONField()
+class OpenAIProcessorConversationConfig(BaseModel):
+    api_key = models.CharField(max_length=200)
+    chat_model = models.CharField(max_length=200)
+    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
+
+
+class OfflineChatProcessorConversationConfig(BaseModel):
     enable_offline_chat = models.BooleanField(default=False)
+    chat_model = models.CharField(max_length=200, default="llama-2-7b-chat.ggmlv3.q4_0.bin")
+    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
+
+
+class ConversationProcessorConfig(BaseModel):
+    max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
+    tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
+    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
+
+
+class Conversation(BaseModel):
+    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
+    conversation_log = models.JSONField(default=dict)
+
+
 class Embeddings(BaseModel):
@@ -23,12 +23,10 @@ from starlette.authentication import (
 from khoj.utils import constants, state
 from khoj.utils.config import (
     SearchType,
-    ProcessorConfigModel,
-    ConversationProcessorConfigModel,
 )
-from khoj.utils.helpers import resolve_absolute_path, merge_dicts
+from khoj.utils.helpers import merge_dicts
 from khoj.utils.fs_syncer import collect_files
-from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig
+from khoj.utils.rawconfig import FullConfig
 from khoj.routers.indexer import configure_content, load_content, configure_search
 from database.models import KhojUser
 from database.adapters import get_all_users
@@ -98,13 +96,6 @@ def configure_server(
     # Update Config
     state.config = config
 
-    # Initialize Processor from Config
-    try:
-        state.processor_config = configure_processor(state.config.processor)
-    except Exception as e:
-        logger.error(f"🚨 Failed to configure processor", exc_info=True)
-        raise e
-
     # Initialize Search Models from Config and initialize content
     try:
         state.config_lock.acquire()
@@ -190,103 +181,6 @@ def configure_search_types(config: FullConfig):
     return Enum("SearchType", merge_dicts(core_search_types, {}))
 
 
-def configure_processor(
-    processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
-):
-    if not processor_config:
-        logger.warning("🚨 No Processor configuration available.")
-        return None
-
-    processor = ProcessorConfigModel()
-
-    # Initialize Conversation Processor
-    logger.info("💬 Setting up conversation processor")
-    processor.conversation = configure_conversation_processor(processor_config, state_processor_config)
-
-    return processor
-
-
-def configure_conversation_processor(
-    processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
-):
-    if (
-        not processor_config
-        or not processor_config.conversation
-        or not processor_config.conversation.conversation_logfile
-    ):
-        default_config = constants.default_config
-        default_conversation_logfile = resolve_absolute_path(
-            default_config["processor"]["conversation"]["conversation-logfile"]  # type: ignore
-        )
-        conversation_logfile = resolve_absolute_path(default_conversation_logfile)
-        conversation_config = processor_config.conversation if processor_config else None
-        conversation_processor = ConversationProcessorConfigModel(
-            conversation_config=ConversationProcessorConfig(
-                conversation_logfile=conversation_logfile,
-                openai=(conversation_config.openai if (conversation_config is not None) else None),
-                offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(),
-                max_prompt_size=conversation_config.max_prompt_size if conversation_config else None,
-                tokenizer=conversation_config.tokenizer if conversation_config else None,
-            )
-        )
-    else:
-        conversation_processor = ConversationProcessorConfigModel(
-            conversation_config=processor_config.conversation,
-        )
-        conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile)
-
-    # Load Conversation Logs from Disk
-    if state_processor_config and state_processor_config.conversation and state_processor_config.conversation.meta_log:
-        conversation_processor.meta_log = state_processor_config.conversation.meta_log
-        conversation_processor.chat_session = state_processor_config.conversation.chat_session
-        logger.debug(f"Loaded conversation logs from state")
-        return conversation_processor
-
-    if conversation_logfile.is_file():
-        # Load Metadata Logs from Conversation Logfile
-        with conversation_logfile.open("r") as f:
-            conversation_processor.meta_log = json.load(f)
-        logger.debug(f"Loaded conversation logs from {conversation_logfile}")
-    else:
-        # Initialize Conversation Logs
-        conversation_processor.meta_log = {}
-        conversation_processor.chat_session = []
-
-    return conversation_processor
-
-
-@schedule.repeat(schedule.every(17).minutes)
-def save_chat_session():
-    # No need to create empty log file
-    if not (
-        state.processor_config
-        and state.processor_config.conversation
-        and state.processor_config.conversation.meta_log
-        and state.processor_config.conversation.chat_session
-    ):
-        return
-
-    # Summarize Conversation Logs for this Session
-    conversation_log = state.processor_config.conversation.meta_log
-    session = {
-        "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
-        "session-end": len(conversation_log["chat"]),
-    }
-    if "session" in conversation_log:
-        conversation_log["session"].append(session)
-    else:
-        conversation_log["session"] = [session]
-
-    # Save Conversation Metadata Logs to Disk
-    conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile)
-    conversation_logfile.parent.mkdir(parents=True, exist_ok=True)  # create conversation directory if doesn't exist
-    with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
-        json.dump(conversation_log, logfile, indent=2)
-
-    state.processor_config.conversation.chat_session = []
-    logger.info("📩 Saved current chat session to conversation logs")
-
-
 @schedule.repeat(schedule.every(59).minutes)
 def upload_telemetry():
     if not state.config or not state.config.app or not state.config.app.should_log_telemetry or not state.telemetry:
@@ -3,6 +3,11 @@
 
 <div class="page">
     <div class="section">
+        {% if anonymous_mode == False %}
+        <div>
+            Logged in as {{ username }}
+        </div>
+        {% endif %}
         <h2 class="section-title">Plugins</h2>
         <div class="section-cards">
             <div class="card">
@@ -257,8 +262,8 @@
                     <img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
                     <h3 class="card-title">
                         Offline Chat
-                        <img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
-                        {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and not current_model_state.conversation_gpt4all %}
+                        <img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_model_state.enable_offline_model and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
+                        {% if current_model_state.enable_offline_model and not current_model_state.conversation_gpt4all %}
                         <img id="misconfigured-icon-conversation-enable-offline-chat" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The model was not downloaded as expected.">
                         {% endif %}
                     </h3>
@@ -266,12 +271,12 @@
                 <div class="card-description-row">
                     <p class="card-description">Setup offline chat</p>
                 </div>
-                <div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
+                <div id="clear-enable-offline-chat" class="card-action-row {% if current_model_state.enable_offline_model %}enabled{% else %}disabled{% endif %}">
                     <button class="card-button" onclick="toggleEnableLocalLLLM(false)">
                         Disable
                     </button>
                 </div>
-                <div id="set-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}disabled{% else %}enabled{% endif %}">
+                <div id="set-enable-offline-chat" class="card-action-row {% if current_model_state.enable_offline_model %}disabled{% else %}enabled{% endif %}">
                     <button class="card-button happy" onclick="toggleEnableLocalLLLM(true)">
                         Enable
                     </button>
@@ -8,21 +8,20 @@ from typing import List, Optional, Union, Any
 import asyncio
 
 # External Packages
-from fastapi import APIRouter, HTTPException, Header, Request, Depends
+from fastapi import APIRouter, HTTPException, Header, Request
 from starlette.authentication import requires
 from asgiref.sync import sync_to_async
 
 # Internal Packages
-from khoj.configure import configure_processor, configure_server
+from khoj.configure import configure_server
 from khoj.search_type import image_search, text_search
 from khoj.search_filter.date_filter import DateFilter
 from khoj.search_filter.file_filter import FileFilter
 from khoj.search_filter.word_filter import WordFilter
-from khoj.utils.config import TextSearchModel
+from khoj.utils.config import TextSearchModel, GPT4AllProcessorModel
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer, command_descriptions
 from khoj.utils.rawconfig import (
     FullConfig,
-    ProcessorConfig,
     SearchConfig,
     SearchResponse,
     TextContentConfig,
@@ -32,16 +31,16 @@ from khoj.utils.rawconfig import (
     ConversationProcessorConfig,
     OfflineChatProcessorConfig,
 )
-from khoj.utils.helpers import resolve_absolute_path
 from khoj.utils.state import SearchType
 from khoj.utils import state, constants
-from khoj.utils.yaml import save_config_to_file_updated_state
+from khoj.utils.helpers import AsyncIteratorWrapper
 from fastapi.responses import StreamingResponse, Response
 from khoj.routers.helpers import (
     get_conversation_command,
     perform_chat_checks,
-    generate_chat_response,
+    agenerate_chat_response,
     update_telemetry_state,
+    is_ready_to_chat,
 )
 from khoj.processor.conversation.prompts import help_message
 from khoj.processor.conversation.openai.gpt import extract_questions
@@ -49,7 +48,7 @@ from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline
 from fastapi.requests import Request
 
 from database import adapters
-from database.adapters import EmbeddingsAdapters
+from database.adapters import EmbeddingsAdapters, ConversationAdapters
 from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser
 
 
@@ -114,6 +113,8 @@ async def map_config_to_db(config: FullConfig, user: KhojUser):
             user=user,
             token=config.content_type.notion.token,
         )
+    if config.processor and config.processor.conversation:
+        ConversationAdapters.set_conversation_processor_config(user, config.processor.conversation)
 
 
 # If it's a demo instance, prevent updating any of the configuration.
@@ -123,8 +124,6 @@ if not state.demo:
     if state.config is None:
         state.config = FullConfig()
         state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
-    if state.processor_config is None:
-        state.processor_config = configure_processor(state.config.processor)
 
     @api.get("/config/data", response_model=FullConfig)
     @requires(["authenticated"], redirect="login_page")
@@ -238,28 +237,24 @@ if not state.demo:
         )
 
         content_object = map_config_to_object(content_type)
+        if content_object is None:
+            raise ValueError(f"Invalid content type: {content_type}")
+
         await content_object.objects.filter(user=user).adelete()
         await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type)
 
         enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
 
         return {"status": "ok"}
 
     @api.post("/delete/config/data/processor/conversation/openai", status_code=200)
+    @requires(["authenticated"], redirect="login_page")
     async def remove_processor_conversation_config_data(
         request: Request,
         client: Optional[str] = None,
     ):
-        if (
-            not state.config
-            or not state.config.processor
-            or not state.config.processor.conversation
-            or not state.config.processor.conversation.openai
-        ):
-            return {"status": "ok"}
-
-        state.config.processor.conversation.openai = None
-        state.processor_config = configure_processor(state.config.processor, state.processor_config)
+        user = request.user.object
+        await sync_to_async(ConversationAdapters.clear_openai_conversation_config)(user)
 
         update_telemetry_state(
             request=request,
|
||||||
metadata={"processor_conversation_type": "openai"},
|
metadata={"processor_conversation_type": "openai"},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
save_config_to_file_updated_state()
|
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
except Exception as e:
|
|
||||||
return {"status": "error", "message": str(e)}
|
|
||||||
|
|
||||||
@api.post("/config/data/content_type/{content_type}", status_code=200)
|
@api.post("/config/data/content_type/{content_type}", status_code=200)
|
||||||
@requires(["authenticated"], redirect="login_page")
|
@requires(["authenticated"], redirect="login_page")
|
||||||
|
@@ -301,24 +292,17 @@ if not state.demo:
         return {"status": "ok"}
 
     @api.post("/config/data/processor/conversation/openai", status_code=200)
+    @requires(["authenticated"], redirect="login_page")
    async def set_processor_openai_config_data(
         request: Request,
         updated_config: Union[OpenAIProcessorConfig, None],
         client: Optional[str] = None,
     ):
-        _initialize_config()
+        user = request.user.object
 
-        if not state.config.processor or not state.config.processor.conversation:
-            default_config = constants.default_config
-            default_conversation_logfile = resolve_absolute_path(
-                default_config["processor"]["conversation"]["conversation-logfile"]  # type: ignore
-            )
-            conversation_logfile = resolve_absolute_path(default_conversation_logfile)
-            state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile))  # type: ignore
-
-        assert state.config.processor.conversation is not None
-        state.config.processor.conversation.openai = updated_config
-        state.processor_config = configure_processor(state.config.processor, state.processor_config)
+        conversation_config = ConversationProcessorConfig(openai=updated_config)
+
+        await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config)
 
         update_telemetry_state(
             request=request,
@@ -328,11 +312,7 @@ if not state.demo:
             metadata={"processor_conversation_type": "conversation"},
         )
 
-        try:
-            save_config_to_file_updated_state()
-            return {"status": "ok"}
-        except Exception as e:
-            return {"status": "error", "message": str(e)}
+        return {"status": "ok"}
 
     @api.post("/config/data/processor/conversation/offline_chat", status_code=200)
     async def set_processor_enable_offline_chat_config_data(
@@ -341,24 +321,26 @@ if not state.demo:
        offline_chat_model: Optional[str] = None,
        client: Optional[str] = None,
    ):
-        _initialize_config()
+        user = request.user.object
 
-        if not state.config.processor or not state.config.processor.conversation:
-            default_config = constants.default_config
-            default_conversation_logfile = resolve_absolute_path(
-                default_config["processor"]["conversation"]["conversation-logfile"]  # type: ignore
-            )
-            conversation_logfile = resolve_absolute_path(default_conversation_logfile)
-            state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile))  # type: ignore
-
-        assert state.config.processor.conversation is not None
-        if state.config.processor.conversation.offline_chat is None:
-            state.config.processor.conversation.offline_chat = OfflineChatProcessorConfig()
-
-        state.config.processor.conversation.offline_chat.enable_offline_chat = enable_offline_chat
-        if offline_chat_model is not None:
-            state.config.processor.conversation.offline_chat.chat_model = offline_chat_model
-        state.processor_config = configure_processor(state.config.processor, state.processor_config)
+        if enable_offline_chat:
+            conversation_config = ConversationProcessorConfig(
+                offline_chat=OfflineChatProcessorConfig(
+                    enable_offline_chat=enable_offline_chat,
+                    chat_model=offline_chat_model,
+                )
+            )
+
+            await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config)
+
+            offline_chat = await ConversationAdapters.get_offline_chat(user)
+            chat_model = offline_chat.chat_model
+            if state.gpt4all_processor_config is None:
+                state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
+
+        else:
+            await sync_to_async(ConversationAdapters.clear_offline_chat_conversation_config)(user)
+            state.gpt4all_processor_config = None
 
         update_telemetry_state(
             request=request,
|
||||||
metadata={"processor_conversation_type": f"{'enable' if enable_offline_chat else 'disable'}_local_llm"},
|
metadata={"processor_conversation_type": f"{'enable' if enable_offline_chat else 'disable'}_local_llm"},
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
save_config_to_file_updated_state()
|
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
except Exception as e:
|
|
||||||
return {"status": "error", "message": str(e)}
|
|
||||||
|
|
||||||
|
|
||||||
# Create Routes
|
# Create Routes
|
||||||
|
@@ -426,9 +404,6 @@ async def search(
     if q is None or q == "":
         logger.warning(f"No query param (q) passed in API call to initiate search")
         return results
-    if not state.search_models or not any(state.search_models.__dict__.values()):
-        logger.warning(f"No search models loaded. Configure a search model before initiating search")
-        return results
 
     # initialize variables
     user_query = q.strip()
@@ -565,8 +540,6 @@ def update(
         components.append("Search models")
     if state.content_index:
         components.append("Content index")
-    if state.processor_config:
-        components.append("Conversation processor")
     components_msg = ", ".join(components)
     logger.info(f"📪 {components_msg} updated via API")
 
@@ -592,12 +565,11 @@ def chat_history(
     referer: Optional[str] = Header(None),
     host: Optional[str] = Header(None),
 ):
-    perform_chat_checks()
+    user = request.user.object
+    perform_chat_checks(user)
 
     # Load Conversation History
-    meta_log = {}
-    if state.processor_config.conversation:
-        meta_log = state.processor_config.conversation.meta_log
+    meta_log = ConversationAdapters.get_conversation_by_user(user=user).conversation_log
 
     update_telemetry_state(
         request=request,
@@ -649,30 +621,35 @@ async def chat(
     referer: Optional[str] = Header(None),
     host: Optional[str] = Header(None),
 ) -> Response:
-    perform_chat_checks()
+    user = request.user.object
+
+    await is_ready_to_chat(user)
     conversation_command = get_conversation_command(query=q, any_references=True)
 
     q = q.replace(f"/{conversation_command.value}", "").strip()
 
+    meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
+
     compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
-        request, q, (n or 5), conversation_command
+        request, meta_log, q, (n or 5), conversation_command
     )
 
     if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
         conversation_command = ConversationCommand.General
 
     if conversation_command == ConversationCommand.Help:
-        model_type = "offline" if state.processor_config.conversation.offline_chat.enable_offline_chat else "openai"
+        model_type = "offline" if await ConversationAdapters.has_offline_chat(user) else "openai"
         formatted_help = help_message.format(model=model_type, version=state.khoj_version)
         return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
 
     # Get the (streamed) chat response from the LLM of choice.
-    llm_response = generate_chat_response(
+    llm_response = await agenerate_chat_response(
         defiltered_query,
-        meta_log=state.processor_config.conversation.meta_log,
-        compiled_references=compiled_references,
-        inferred_queries=inferred_queries,
-        conversation_command=conversation_command,
+        meta_log,
+        compiled_references,
+        inferred_queries,
+        conversation_command,
+        user,
     )
 
     if llm_response is None:
@@ -681,13 +658,14 @@ async def chat(
     if stream:
         return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
 
+    iterator = AsyncIteratorWrapper(llm_response)
+
     # Get the full response from the generator if the stream is not requested.
     aggregated_gpt_response = ""
-    while True:
-        try:
-            aggregated_gpt_response += next(llm_response)
-        except StopIteration:
-            break
+    async for item in iterator:
+        if item is None:
+            break
+        aggregated_gpt_response += item
 
     actual_response = aggregated_gpt_response.split("### compiled references:")[0]
@@ -708,44 +686,53 @@ async def chat(
 
 async def extract_references_and_questions(
     request: Request,
+    meta_log: dict,
     q: str,
     n: int,
     conversation_type: ConversationCommand = ConversationCommand.Default,
 ):
     user = request.user.object if request.user.is_authenticated else None
-    # Load Conversation History
-    meta_log = state.processor_config.conversation.meta_log
 
     # Initialize Variables
     compiled_references: List[Any] = []
     inferred_queries: List[str] = []
 
-    if not EmbeddingsAdapters.user_has_embeddings(user=user):
+    if conversation_type == ConversationCommand.General:
+        return compiled_references, inferred_queries, q
+
+    if not await EmbeddingsAdapters.user_has_embeddings(user=user):
         logger.warning(
             "No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
         )
         return compiled_references, inferred_queries, q
 
-    if conversation_type == ConversationCommand.General:
-        return compiled_references, inferred_queries, q
-
     # Extract filter terms from user message
     defiltered_query = q
     for filter in [DateFilter(), WordFilter(), FileFilter()]:
         defiltered_query = filter.defilter(defiltered_query)
     filters_in_query = q.replace(defiltered_query, "").strip()
 
+    using_offline_chat = False
+
     # Infer search queries from user message
     with timer("Extracting search queries took", logger):
         # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
-        if state.processor_config.conversation.offline_chat.enable_offline_chat:
-            loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
+        if await ConversationAdapters.has_offline_chat(user):
+            using_offline_chat = True
+            offline_chat = await ConversationAdapters.get_offline_chat(user)
+            chat_model = offline_chat.chat_model
+            if state.gpt4all_processor_config is None:
+                state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
+
+            loaded_model = state.gpt4all_processor_config.loaded_model
+
             inferred_queries = extract_questions_offline(
                 defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
             )
-        elif state.processor_config.conversation.openai_model:
-            api_key = state.processor_config.conversation.openai_model.api_key
-            chat_model = state.processor_config.conversation.openai_model.chat_model
+        elif await ConversationAdapters.has_openai_chat(user):
+            openai_chat = await ConversationAdapters.get_openai_chat(user)
+            api_key = openai_chat.api_key
+            chat_model = openai_chat.chat_model
             inferred_queries = extract_questions(
                 defiltered_query, model=chat_model, api_key=api_key, conversation_log=meta_log
             )
@@ -754,7 +741,7 @@ async def extract_references_and_questions(
     with timer("Searching knowledge base took", logger):
         result_list = []
         for query in inferred_queries:
-            n_items = min(n, 3) if state.processor_config.conversation.offline_chat.enable_offline_chat else n
+            n_items = min(n, 3) if using_offline_chat else n
             result_list.extend(
                 await search(
                     f"{query} {filters_in_query}",
|
||||||
dedupe=False,
|
dedupe=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
# Dedupe the results again, as duplicates may be returned across queries.
|
||||||
|
result_list = text_search.deduplicated_search_responses(result_list)
|
||||||
compiled_references = [item.additional["compiled"] for item in result_list]
|
compiled_references = [item.additional["compiled"] for item in result_list]
|
||||||
|
|
||||||
return compiled_references, inferred_queries, defiltered_query
|
return compiled_references, inferred_queries, defiltered_query
|
||||||
|
|
|
@@ -1,34 +1,50 @@
 import logging
+import asyncio
 from datetime import datetime
 from functools import partial
 from typing import Iterator, List, Optional, Union
+from concurrent.futures import ThreadPoolExecutor
 
 from fastapi import HTTPException, Request
 
 from khoj.utils import state
-from khoj.utils.helpers import ConversationCommand, timer, log_telemetry
+from khoj.utils.config import GPT4AllProcessorModel
+from khoj.utils.helpers import ConversationCommand, log_telemetry
 from khoj.processor.conversation.openai.gpt import converse
 from khoj.processor.conversation.gpt4all.chat_model import converse_offline
-from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
+from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
 from database.models import KhojUser
+from database.adapters import ConversationAdapters
 
 logger = logging.getLogger(__name__)
 
+executor = ThreadPoolExecutor(max_workers=1)
+
 
-def perform_chat_checks():
-    if (
-        state.processor_config
-        and state.processor_config.conversation
-        and (
-            state.processor_config.conversation.openai_model
-            or state.processor_config.conversation.gpt4all_model.loaded_model
-        )
-    ):
+def perform_chat_checks(user: KhojUser):
+    if ConversationAdapters.has_valid_offline_conversation_config(
+        user
+    ) or ConversationAdapters.has_valid_openai_conversation_config(user):
         return
 
-    raise HTTPException(
-        status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings and restart it."
-    )
+    raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
+
+
+async def is_ready_to_chat(user: KhojUser):
+    has_offline_config = await ConversationAdapters.has_offline_chat(user=user)
+    has_openai_config = await ConversationAdapters.has_openai_chat(user=user)
+
+    if has_offline_config:
+        offline_chat = await ConversationAdapters.get_offline_chat(user)
+        chat_model = offline_chat.chat_model
+        if state.gpt4all_processor_config is None:
+            state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
+        return True
+
+    ready = has_openai_config or has_offline_config
+
+    if not ready:
+        raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
 
 
 def update_telemetry_state(
@@ -74,12 +90,22 @@ def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
     return ConversationCommand.Default
 
 
+async def construct_conversation_logs(user: KhojUser):
+    return (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
+
+
+async def agenerate_chat_response(*args):
+    loop = asyncio.get_event_loop()
+    return await loop.run_in_executor(executor, generate_chat_response, *args)
+
+
 def generate_chat_response(
     q: str,
     meta_log: dict,
     compiled_references: List[str] = [],
     inferred_queries: List[str] = [],
     conversation_command: ConversationCommand = ConversationCommand.Default,
+    user: KhojUser = None,
 ) -> Union[ThreadedGenerator, Iterator[str]]:
     def _save_to_conversation_log(
         q: str,
@@ -89,17 +115,14 @@ def generate_chat_response(
         inferred_queries: List[str],
         meta_log,
     ):
-        state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, chat_response])
-        state.processor_config.conversation.meta_log["chat"] = message_to_log(
+        updated_conversation = message_to_log(
             user_message=q,
             chat_response=chat_response,
            user_message_metadata={"created": user_message_time},
            khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
            conversation_log=meta_log.get("chat", []),
         )
+        ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
 
-    # Load Conversation History
-    meta_log = state.processor_config.conversation.meta_log
-
     # Initialize Variables
     user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -116,8 +139,14 @@ def generate_chat_response(
             meta_log=meta_log,
         )
 
-        if state.processor_config.conversation.offline_chat.enable_offline_chat:
-            loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
+        offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user=user)
+        conversation_config = ConversationAdapters.get_conversation_config(user)
+        openai_chat_config = ConversationAdapters.get_openai_conversation_config(user)
+        if offline_chat_config:
+            if state.gpt4all_processor_config.loaded_model is None:
+                state.gpt4all_processor_config = GPT4AllProcessorModel(offline_chat_config.chat_model)
+
+            loaded_model = state.gpt4all_processor_config.loaded_model
             chat_response = converse_offline(
                 references=compiled_references,
                 user_query=q,
@ -125,14 +154,14 @@ def generate_chat_response(
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
conversation_command=conversation_command,
|
conversation_command=conversation_command,
|
||||||
model=state.processor_config.conversation.offline_chat.chat_model,
|
model=offline_chat_config.chat_model,
|
||||||
max_prompt_size=state.processor_config.conversation.max_prompt_size,
|
max_prompt_size=conversation_config.max_prompt_size,
|
||||||
tokenizer_name=state.processor_config.conversation.tokenizer,
|
tokenizer_name=conversation_config.tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif state.processor_config.conversation.openai_model:
|
elif openai_chat_config:
|
||||||
api_key = state.processor_config.conversation.openai_model.api_key
|
api_key = openai_chat_config.api_key
|
||||||
chat_model = state.processor_config.conversation.openai_model.chat_model
|
chat_model = openai_chat_config.chat_model
|
||||||
chat_response = converse(
|
chat_response = converse(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
q,
|
q,
|
||||||
|
@ -141,8 +170,8 @@ def generate_chat_response(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
conversation_command=conversation_command,
|
conversation_command=conversation_command,
|
||||||
max_prompt_size=state.processor_config.conversation.max_prompt_size,
|
max_prompt_size=conversation_config.max_prompt_size if conversation_config else None,
|
||||||
tokenizer_name=state.processor_config.conversation.tokenizer,
|
tokenizer_name=conversation_config.tokenizer if conversation_config else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
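Note: the hunk above swaps process-global chat settings for per-user rows fetched through `ConversationAdapters`. The selection order is unchanged: offline chat wins if configured, else OpenAI. A rough standalone sketch of that resolution logic, with simplified stand-in config types:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class OfflineChatConfig:
        chat_model: str

    @dataclass
    class OpenAIChatConfig:
        api_key: str
        chat_model: str

    def pick_backend(offline: Optional[OfflineChatConfig], openai: Optional[OpenAIChatConfig]) -> str:
        # Mirror of the elif chain above: the offline config takes precedence.
        if offline:
            return f"offline:{offline.chat_model}"
        elif openai:
            return f"openai:{openai.chat_model}"
        raise RuntimeError("no chat backend configured for this user")

    print(pick_backend(None, OpenAIChatConfig("sk-test", "gpt-3.5-turbo")))  # openai:gpt-3.5-turbo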
@@ -92,7 +92,7 @@ async def update(

         if dict_to_update is not None:
             dict_to_update[file.filename] = (
-                file.file.read().decode("utf-8") if encoding == "utf-8" else file.file.read()
+                file.file.read().decode("utf-8") if encoding == "utf-8" else file.file.read()  # type: ignore
             )
         else:
             logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")
@@ -181,24 +181,25 @@ def configure_content(
     files: Optional[dict[str, dict[str, str]]],
     search_models: SearchModels,
     regenerate: bool = False,
-    t: Optional[Union[state.SearchType, str]] = None,
+    t: Optional[state.SearchType] = None,
     full_corpus: bool = True,
     user: KhojUser = None,
 ) -> Optional[ContentIndex]:
     content_index = ContentIndex()

-    if t in [type.value for type in state.SearchType]:
-        t = state.SearchType(t).value
+    if t is not None and not t.value in [type.value for type in state.SearchType]:
+        logger.warning(f"🚨 Invalid search type: {t}")
+        return None

-    assert type(t) == str or t == None, f"Invalid search type: {t}"
+    search_type = t.value if t else None

     if files is None:
-        logger.warning(f"🚨 No files to process for {t} search.")
+        logger.warning(f"🚨 No files to process for {search_type} search.")
         return None

     try:
         # Initialize Org Notes Search
-        if (t == None or t == state.SearchType.Org.value) and files["org"]:
+        if (search_type == None or search_type == state.SearchType.Org.value) and files["org"]:
             logger.info("🦄 Setting up search for orgmode notes")
             # Extract Entries, Generate Notes Embeddings
             text_search.setup(

@@ -213,7 +214,7 @@ def configure_content(

     try:
         # Initialize Markdown Search
-        if (t == None or t == state.SearchType.Markdown.value) and files["markdown"]:
+        if (search_type == None or search_type == state.SearchType.Markdown.value) and files["markdown"]:
             logger.info("💎 Setting up search for markdown notes")
             # Extract Entries, Generate Markdown Embeddings
             text_search.setup(

@@ -229,7 +230,7 @@ def configure_content(

     try:
         # Initialize PDF Search
-        if (t == None or t == state.SearchType.Pdf.value) and files["pdf"]:
+        if (search_type == None or search_type == state.SearchType.Pdf.value) and files["pdf"]:
             logger.info("🖨️ Setting up search for pdf")
             # Extract Entries, Generate PDF Embeddings
             text_search.setup(

@@ -245,7 +246,7 @@ def configure_content(

     try:
         # Initialize Plaintext Search
-        if (t == None or t == state.SearchType.Plaintext.value) and files["plaintext"]:
+        if (search_type == None or search_type == state.SearchType.Plaintext.value) and files["plaintext"]:
             logger.info("📄 Setting up search for plaintext")
             # Extract Entries, Generate Plaintext Embeddings
             text_search.setup(

@@ -262,7 +263,7 @@ def configure_content(
     try:
         # Initialize Image Search
         if (
-            (t == None or t == state.SearchType.Image.value)
+            (search_type == None or search_type == state.SearchType.Image.value)
             and content_config
             and content_config.image
             and search_models.image_search

@@ -278,7 +279,7 @@ def configure_content(

     try:
         github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
-        if (t == None or t == state.SearchType.Github.value) and github_config is not None:
+        if (search_type == None or search_type == state.SearchType.Github.value) and github_config is not None:
             logger.info("🐙 Setting up search for github")
             # Extract Entries, Generate Github Embeddings
             text_search.setup(

@@ -296,7 +297,7 @@ def configure_content(
     try:
         # Initialize Notion Search
         notion_config = NotionConfig.objects.filter(user=user).first()
-        if (t == None or t in state.SearchType.Notion.value) and notion_config:
+        if (search_type == None or search_type in state.SearchType.Notion.value) and notion_config:
             logger.info("🔌 Setting up search for notion")
             text_search.setup(
                 NotionToJsonl,
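Note: `configure_content` now takes the `SearchType` enum directly, rejects unknown values up front, and threads the plain string `search_type` through the per-content-type guards. A toy version of that normalization step (enum members shortened for the example):

    from enum import Enum
    from typing import Optional

    class SearchType(str, Enum):
        Org = "org"
        Markdown = "markdown"

    def normalize(t: Optional[SearchType]) -> Optional[str]:
        # Guard against values smuggled in from outside the enum, then
        # work with the string form ("org", "markdown", or None for all types).
        if t is not None and t.value not in [st.value for st in SearchType]:
            return None  # the real code logs a warning and aborts here
        return t.value if t else None

    assert normalize(SearchType.Org) == "org"
    assert normalize(None) is None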
@@ -19,7 +19,7 @@ from khoj.utils.rawconfig import (

 # Internal Packages
 from khoj.utils import constants, state
-from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config
+from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters
 from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig


@@ -83,7 +83,7 @@ if not state.demo:
     @web_client.get("/config", response_class=HTMLResponse)
     @requires(["authenticated"], redirect="login_page")
     def config_page(request: Request):
-        user = request.user.object if request.user.is_authenticated else None
+        user = request.user.object
         enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
         default_full_config = FullConfig(
             content_type=None,

@@ -100,9 +100,6 @@ if not state.demo:
             "github": ("github" in enabled_content),
             "notion": ("notion" in enabled_content),
             "plaintext": ("plaintext" in enabled_content),
-            "enable_offline_model": False,
-            "conversation_openai": False,
-            "conversation_gpt4all": False,
         }

         if state.content_index:

@@ -112,11 +109,15 @@ if not state.demo:
                 }
             )

-        if state.processor_config and state.processor_config.conversation:
+        enabled_chat_config = ConversationAdapters.get_enabled_conversation_settings(user)

         successfully_configured.update(
             {
-                "conversation_openai": state.processor_config.conversation.openai_model is not None,
-                "conversation_gpt4all": state.processor_config.conversation.gpt4all_model.loaded_model is not None,
+                "conversation_openai": enabled_chat_config["openai"],
+                "enable_offline_model": enabled_chat_config["offline_chat"],
+                "conversation_gpt4all": state.gpt4all_processor_config.loaded_model is not None
+                if state.gpt4all_processor_config
+                else False,
             }
         )

@@ -127,6 +128,7 @@ if not state.demo:
                 "current_config": current_config,
                 "current_model_state": successfully_configured,
                 "anonymous_mode": state.anonymous_mode,
+                "username": user.username if user else None,
             },
         )

@@ -204,22 +206,22 @@ if not state.demo:
     )

     @web_client.get("/config/processor/conversation/openai", response_class=HTMLResponse)
+    @requires(["authenticated"], redirect="login_page")
     def conversation_processor_config_page(request: Request):
-        default_copy = constants.default_config.copy()
-        default_processor_config = default_copy["processor"]["conversation"]["openai"]  # type: ignore
-        default_openai_config = OpenAIProcessorConfig(
-            api_key="",
-            chat_model=default_processor_config["chat-model"],
-        )
-
-        current_processor_openai_config = (
-            state.config.processor.conversation.openai
-            if state.config
-            and state.config.processor
-            and state.config.processor.conversation
-            and state.config.processor.conversation.openai
-            else default_openai_config
-        )
+        user = request.user.object
+        openai_config = ConversationAdapters.get_openai_conversation_config(user)
+        if openai_config:
+            current_processor_openai_config = OpenAIProcessorConfig(
+                api_key=openai_config.api_key,
+                chat_model=openai_config.chat_model,
+            )
+        else:
+            current_processor_openai_config = OpenAIProcessorConfig(
+                api_key="",
+                chat_model="gpt-3.5-turbo",
+            )
+
         current_processor_openai_config = json.loads(current_processor_openai_config.json())

         return templates.TemplateResponse(
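Note: the config page now derives its chat toggles from `ConversationAdapters.get_enabled_conversation_settings(user)`, which the template code above treats as a dict with `openai` and `offline_chat` keys. A plausible shape for such an adapter, sketched with stubbed lookups (not the actual implementation):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class OfflineChatConfig:
        enable_offline_chat: bool = True

    def get_openai_conversation_config(user) -> Optional[object]:
        return None  # pretend this user has no OpenAI row

    def get_offline_chat_conversation_config(user) -> Optional[OfflineChatConfig]:
        return OfflineChatConfig()

    def get_enabled_conversation_settings(user) -> dict:
        return {
            "openai": get_openai_conversation_config(user) is not None,
            "offline_chat": bool(get_offline_chat_conversation_config(user)),
        }

    print(get_enabled_conversation_settings(user=None))
    # -> {'openai': False, 'offline_chat': True}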
@@ -236,6 +236,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
                     "image_score": f"{hit['image_score']:.9f}",
                     "metadata_score": f"{hit['metadata_score']:.9f}",
                 },
+                "corpus_id": hit["corpus_id"],
             }
         )
     ]
@@ -14,10 +14,9 @@ from asgiref.sync import sync_to_async
 # Internal Packages
 from khoj.utils import state
 from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer
-from khoj.utils.config import TextSearchModel
 from khoj.utils.models import BaseEncoder
 from khoj.utils.state import SearchType
-from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, Entry
+from khoj.utils.rawconfig import SearchResponse, Entry
 from khoj.utils.jsonl import load_jsonl
 from khoj.processor.text_to_jsonl import TextEmbeddings
 from database.adapters import EmbeddingsAdapters

@@ -36,36 +35,6 @@ search_type_to_embeddings_type = {
 }


-def initialize_model(search_config: TextSearchConfig):
-    "Initialize model for semantic search on text"
-    torch.set_num_threads(4)
-
-    # If model directory is configured
-    if search_config.model_directory:
-        # Convert model directory to absolute path
-        search_config.model_directory = resolve_absolute_path(search_config.model_directory)
-        # Create model directory if it doesn't exist
-        search_config.model_directory.parent.mkdir(parents=True, exist_ok=True)
-
-    # The bi-encoder encodes all entries to use for semantic search
-    bi_encoder = load_model(
-        model_dir=search_config.model_directory,
-        model_name=search_config.encoder,
-        model_type=search_config.encoder_type or SentenceTransformer,
-        device=f"{state.device}",
-    )
-
-    # The cross-encoder re-ranks the results to improve quality
-    cross_encoder = load_model(
-        model_dir=search_config.model_directory,
-        model_name=search_config.cross_encoder,
-        model_type=CrossEncoder,
-        device=f"{state.device}",
-    )
-
-    return TextSearchModel(bi_encoder, cross_encoder)
-
-
 def extract_entries(jsonl_file) -> List[Entry]:
     "Load entries from compressed jsonl"
     return list(map(Entry.from_dict, load_jsonl(jsonl_file)))

@@ -176,6 +145,7 @@ def collate_results(hits, dedupe=True):
         {
             "entry": hit.raw,
             "score": hit.distance,
+            "corpus_id": str(hit.corpus_id),
             "additional": {
                 "file": hit.file_path,
                 "compiled": hit.compiled,

@@ -185,6 +155,28 @@ def collate_results(hits, dedupe=True):
     )


+def deduplicated_search_responses(hits: List[SearchResponse]):
+    hit_ids = set()
+    for hit in hits:
+        if hit.corpus_id in hit_ids:
+            continue
+
+        else:
+            hit_ids.add(hit.corpus_id)
+            yield SearchResponse.parse_obj(
+                {
+                    "entry": hit.entry,
+                    "score": hit.score,
+                    "corpus_id": hit.corpus_id,
+                    "additional": {
+                        "file": hit.additional["file"],
+                        "compiled": hit.additional["compiled"],
+                        "heading": hit.additional["heading"],
+                    },
+                }
+            )
+
+
 def rerank_and_sort_results(hits, query):
     # Score all retrieved entries using the cross-encoder
     hits = cross_encoder_score(query, hits)
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from typing import TYPE_CHECKING, List, Optional, Union, Any
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
|
|
||||||
from khoj.processor.conversation.gpt4all.utils import download_model
|
from khoj.processor.conversation.gpt4all.utils import download_model
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
|
@ -19,9 +18,7 @@ logger = logging.getLogger(__name__)
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sentence_transformers import CrossEncoder
|
from sentence_transformers import CrossEncoder
|
||||||
from khoj.search_filter.base_filter import BaseFilter
|
|
||||||
from khoj.utils.models import BaseEncoder
|
from khoj.utils.models import BaseEncoder
|
||||||
from khoj.utils.rawconfig import ConversationProcessorConfig, Entry, OpenAIProcessorConfig
|
|
||||||
|
|
||||||
|
|
||||||
class SearchType(str, Enum):
|
class SearchType(str, Enum):
|
||||||
|
@ -79,31 +76,15 @@ class GPT4AllProcessorConfig:
|
||||||
loaded_model: Union[Any, None] = None
|
loaded_model: Union[Any, None] = None
|
||||||
|
|
||||||
|
|
||||||
class ConversationProcessorConfigModel:
|
class GPT4AllProcessorModel:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
conversation_config: ConversationProcessorConfig,
|
chat_model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
|
||||||
):
|
):
|
||||||
self.openai_model = conversation_config.openai
|
self.chat_model = chat_model
|
||||||
self.gpt4all_model = GPT4AllProcessorConfig()
|
self.loaded_model = None
|
||||||
self.offline_chat = conversation_config.offline_chat or OfflineChatProcessorConfig()
|
|
||||||
self.max_prompt_size = conversation_config.max_prompt_size
|
|
||||||
self.tokenizer = conversation_config.tokenizer
|
|
||||||
self.conversation_logfile = Path(conversation_config.conversation_logfile)
|
|
||||||
self.chat_session: List[str] = []
|
|
||||||
self.meta_log: dict = {}
|
|
||||||
|
|
||||||
if self.offline_chat.enable_offline_chat:
|
|
||||||
try:
|
try:
|
||||||
self.gpt4all_model.loaded_model = download_model(self.offline_chat.chat_model)
|
self.loaded_model = download_model(self.chat_model)
|
||||||
except Exception as e:
|
except ValueError as e:
|
||||||
self.offline_chat.enable_offline_chat = False
|
self.loaded_model = None
|
||||||
self.gpt4all_model.loaded_model = None
|
|
||||||
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
|
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
|
||||||
else:
|
|
||||||
self.gpt4all_model.loaded_model = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ProcessorConfigModel:
|
|
||||||
conversation: Union[ConversationProcessorConfigModel, None] = None
|
|
||||||
|
|
|
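Note: `GPT4AllProcessorModel` replaces the old conversation state object with a thin wrapper around just the offline model handle; a load failure leaves `loaded_model` as None so callers can fall back to another backend. A self-contained sketch of that construct-and-degrade pattern, with a fake loader standing in for `download_model`:

    import logging

    logger = logging.getLogger(__name__)

    def fake_download_model(name: str):
        # stand-in for khoj's download_model; raises to simulate a failed fetch
        if "missing" in name:
            raise ValueError(f"model {name} not available")
        return object()

    class OfflineChatModel:
        def __init__(self, chat_model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin"):
            self.chat_model = chat_model
            self.loaded_model = None
            try:
                self.loaded_model = fake_download_model(self.chat_model)
            except ValueError as e:
                # keep loaded_model None; callers check it before converse_offline
                logger.error(f"Error while loading offline chat model: {e}", exc_info=True)

    print(OfflineChatModel().loaded_model is not None)     # True
    print(OfflineChatModel("missing-model").loaded_model)  # None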
@@ -8,136 +8,14 @@ telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
 content_directory = "~/.khoj/content/"

 empty_config = {
-    "content-type": {
-        "org": {
-            "input-files": None,
-            "input-filter": None,
-            "compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz",
-            "embeddings-file": "~/.khoj/content/org/org_embeddings.pt",
-            "index-heading-entries": False,
-        },
-        "markdown": {
-            "input-files": None,
-            "input-filter": None,
-            "compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz",
-            "embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt",
-        },
-        "pdf": {
-            "input-files": None,
-            "input-filter": None,
-            "compressed-jsonl": "~/.khoj/content/pdf/pdf.jsonl.gz",
-            "embeddings-file": "~/.khoj/content/pdf/pdf_embeddings.pt",
-        },
-        "plaintext": {
-            "input-files": None,
-            "input-filter": None,
-            "compressed-jsonl": "~/.khoj/content/plaintext/plaintext.jsonl.gz",
-            "embeddings-file": "~/.khoj/content/plaintext/plaintext_embeddings.pt",
-        },
-    },
     "search-type": {
-        "symmetric": {
-            "encoder": "sentence-transformers/all-MiniLM-L6-v2",
-            "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
-            "model_directory": "~/.khoj/search/symmetric/",
-        },
-        "asymmetric": {
-            "encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
-            "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
-            "model_directory": "~/.khoj/search/asymmetric/",
-        },
         "image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
     },
-    "processor": {
-        "conversation": {
-            "openai": {
-                "api-key": None,
-                "chat-model": "gpt-3.5-turbo",
-            },
-            "offline-chat": {
-                "enable-offline-chat": False,
-                "chat-model": "llama-2-7b-chat.ggmlv3.q4_0.bin",
-            },
-            "tokenizer": None,
-            "max-prompt-size": None,
-            "conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
-        }
-    },
 }

 # default app config to use
 default_config = {
-    "content-type": {
-        "org": {
-            "input-files": None,
-            "input-filter": None,
-            "compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz",
-            "embeddings-file": "~/.khoj/content/org/org_embeddings.pt",
-            "index-heading-entries": False,
-        },
-        "markdown": {
-            "input-files": None,
-            "input-filter": None,
-            "compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz",
-            "embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt",
-        },
-        "pdf": {
-            "input-files": None,
-            "input-filter": None,
-            "compressed-jsonl": "~/.khoj/content/pdf/pdf.jsonl.gz",
-            "embeddings-file": "~/.khoj/content/pdf/pdf_embeddings.pt",
-        },
-        "image": {
-            "input-directories": None,
-            "input-filter": None,
-            "embeddings-file": "~/.khoj/content/image/image_embeddings.pt",
-            "batch-size": 50,
-            "use-xmp-metadata": False,
-        },
-        "github": {
-            "pat-token": None,
-            "repos": [],
-            "compressed-jsonl": "~/.khoj/content/github/github.jsonl.gz",
-            "embeddings-file": "~/.khoj/content/github/github_embeddings.pt",
-        },
-        "notion": {
-            "token": None,
-            "compressed-jsonl": "~/.khoj/content/notion/notion.jsonl.gz",
-            "embeddings-file": "~/.khoj/content/notion/notion_embeddings.pt",
-        },
-        "plaintext": {
-            "input-files": None,
-            "input-filter": None,
-            "compressed-jsonl": "~/.khoj/content/plaintext/plaintext.jsonl.gz",
-            "embeddings-file": "~/.khoj/content/plaintext/plaintext_embeddings.pt",
-        },
-    },
     "search-type": {
-        "symmetric": {
-            "encoder": "sentence-transformers/all-MiniLM-L6-v2",
-            "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
-            "model_directory": "~/.khoj/search/symmetric/",
-        },
-        "asymmetric": {
-            "encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
-            "cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
-            "model_directory": "~/.khoj/search/asymmetric/",
-        },
         "image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
     },
-    "processor": {
-        "conversation": {
-            "openai": {
-                "api-key": None,
-                "chat-model": "gpt-3.5-turbo",
-            },
-            "offline-chat": {
-                "enable-offline-chat": False,
-                "chat-model": "llama-2-7b-chat.ggmlv3.q4_0.bin",
-            },
-            "tokenizer": None,
-            "max-prompt-size": None,
-            "conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
-        }
-    },
 }
|
||||||
import torch
|
import torch
|
||||||
from typing import Optional, Union, TYPE_CHECKING
|
from typing import Optional, Union, TYPE_CHECKING
|
||||||
import uuid
|
import uuid
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils import constants
|
from khoj.utils import constants
|
||||||
|
@ -29,6 +30,28 @@ if TYPE_CHECKING:
|
||||||
from khoj.utils.rawconfig import AppConfig
|
from khoj.utils.rawconfig import AppConfig
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncIteratorWrapper:
|
||||||
|
def __init__(self, obj):
|
||||||
|
self._it = iter(obj)
|
||||||
|
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
try:
|
||||||
|
value = await self.next_async()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
return
|
||||||
|
return value
|
||||||
|
|
||||||
|
@sync_to_async
|
||||||
|
def next_async(self):
|
||||||
|
try:
|
||||||
|
return next(self._it)
|
||||||
|
except StopIteration:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
|
||||||
def is_none_or_empty(item):
|
def is_none_or_empty(item):
|
||||||
return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""
|
return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""
|
||||||
|
|
||||||
|
|
|
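Note: `AsyncIteratorWrapper` lets `async for` consume a plain sync iterable, with each `next()` hop pushed through `sync_to_async` so a slow iterator does not stall the event loop. A simplified, runnable equivalent using only the standard library (a sentinel avoids raising StopIteration across the thread boundary):

    import asyncio

    _SENTINEL = object()

    class AsyncIteratorWrapper:
        """Adapt a sync iterable for use with `async for` (asgiref-free sketch)."""

        def __init__(self, obj):
            self._it = iter(obj)

        def __aiter__(self):
            return self

        async def __anext__(self):
            # next() runs on a worker thread; the default value signals exhaustion
            value = await asyncio.to_thread(next, self._it, _SENTINEL)
            if value is _SENTINEL:
                raise StopAsyncIteration
            return value

    async def main():
        async for token in AsyncIteratorWrapper(["Hello", " ", "world"]):
            print(token, end="")

    asyncio.run(main())  # prints: Hello world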
@@ -67,13 +67,6 @@ class ContentConfig(ConfigBase):
     notion: Optional[NotionContentConfig]


-class TextSearchConfig(ConfigBase):
-    encoder: str
-    cross_encoder: str
-    encoder_type: Optional[str]
-    model_directory: Optional[Path]
-
-
 class ImageSearchConfig(ConfigBase):
     encoder: str
     encoder_type: Optional[str]

@@ -81,8 +74,6 @@ class ImageSearchConfig(ConfigBase):


 class SearchConfig(ConfigBase):
-    asymmetric: Optional[TextSearchConfig]
-    symmetric: Optional[TextSearchConfig]
     image: Optional[ImageSearchConfig]


@@ -97,11 +88,10 @@ class OfflineChatProcessorConfig(ConfigBase):


 class ConversationProcessorConfig(ConfigBase):
-    conversation_logfile: Path
-    openai: Optional[OpenAIProcessorConfig]
-    offline_chat: Optional[OfflineChatProcessorConfig]
-    max_prompt_size: Optional[int]
-    tokenizer: Optional[str]
+    openai: Optional[OpenAIProcessorConfig] = None
+    offline_chat: Optional[OfflineChatProcessorConfig] = None
+    max_prompt_size: Optional[int] = None
+    tokenizer: Optional[str] = None


 class ProcessorConfig(ConfigBase):

@@ -125,6 +115,7 @@ class SearchResponse(ConfigBase):
     score: float
     cross_score: Optional[float]
     additional: Optional[dict]
+    corpus_id: str


 class Entry:
@@ -10,7 +10,7 @@ from pathlib import Path

 # Internal Packages
 from khoj.utils import config as utils_config
-from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
+from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
 from khoj.utils.helpers import LRU
 from khoj.utils.rawconfig import FullConfig
 from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel

@@ -21,7 +21,7 @@ search_models = SearchModels()
 embeddings_model = EmbeddingsModel()
 cross_encoder_model = CrossEncoderModel()
 content_index = ContentIndex()
-processor_config = ProcessorConfigModel()
+gpt4all_processor_config: GPT4AllProcessorModel = None
 config_file: Path = None
 verbose: int = 0
 host: str = None
@@ -5,7 +5,6 @@ from pathlib import Path
 import pytest
 from fastapi.staticfiles import StaticFiles
 from fastapi import FastAPI
-import factory
 import os
 from fastapi import FastAPI


@@ -13,7 +12,7 @@ app = FastAPI()


 # Internal Packages
-from khoj.configure import configure_processor, configure_routes, configure_search_types, configure_middleware
+from khoj.configure import configure_routes, configure_search_types, configure_middleware
 from khoj.processor.plaintext.plaintext_to_jsonl import PlaintextToJsonl
 from khoj.search_type import image_search, text_search
 from khoj.utils.config import SearchModels

@@ -21,13 +20,8 @@ from khoj.utils.constants import web_directory
 from khoj.utils.helpers import resolve_absolute_path
 from khoj.utils.rawconfig import (
     ContentConfig,
-    ConversationProcessorConfig,
-    OfflineChatProcessorConfig,
-    OpenAIProcessorConfig,
-    ProcessorConfig,
     ImageContentConfig,
     SearchConfig,
-    TextSearchConfig,
     ImageSearchConfig,
 )
 from khoj.utils import state, fs_syncer

@@ -42,42 +36,25 @@ from database.models import (
     GithubRepoConfig,
 )

+from tests.helpers import (
+    UserFactory,
+    ConversationProcessorConfigFactory,
+    OpenAIProcessorConversationConfigFactory,
+    OfflineChatProcessorConversationConfigFactory,
+)


 @pytest.fixture(autouse=True)
 def enable_db_access_for_all_tests(db):
     pass


-class UserFactory(factory.django.DjangoModelFactory):
-    class Meta:
-        model = KhojUser
-
-    username = factory.Faker("name")
-    email = factory.Faker("email")
-    password = factory.Faker("password")
-    uuid = factory.Faker("uuid4")
-
-
 @pytest.fixture(scope="session")
 def search_config() -> SearchConfig:
     model_dir = resolve_absolute_path("~/.khoj/search")
     model_dir.mkdir(parents=True, exist_ok=True)
     search_config = SearchConfig()

-    search_config.symmetric = TextSearchConfig(
-        encoder="sentence-transformers/all-MiniLM-L6-v2",
-        cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
-        model_directory=model_dir / "symmetric/",
-        encoder_type=None,
-    )
-
-    search_config.asymmetric = TextSearchConfig(
-        encoder="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
-        cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
-        model_directory=model_dir / "asymmetric/",
-        encoder_type=None,
-    )
-
     search_config.image = ImageSearchConfig(
         encoder="sentence-transformers/clip-ViT-B-32",
         model_directory=model_dir / "image/",

@@ -177,55 +154,48 @@ def md_content_config():
     return markdown_config


-@pytest.fixture(scope="session")
-def processor_config(tmp_path_factory):
-    openai_api_key = os.getenv("OPENAI_API_KEY")
-    processor_dir = tmp_path_factory.mktemp("processor")
-
-    # The conversation processor is the only configured processor
-    # It needs an OpenAI API key to work.
-    if not openai_api_key:
-        return
-
-    # Setup conversation processor, if OpenAI API key is set
-    processor_config = ProcessorConfig()
-    processor_config.conversation = ConversationProcessorConfig(
-        openai=OpenAIProcessorConfig(api_key=openai_api_key),
-        conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
-    )
-
-    return processor_config
-
-
-@pytest.fixture(scope="session")
-def processor_config_offline_chat(tmp_path_factory):
-    processor_dir = tmp_path_factory.mktemp("processor")
-
-    # Setup conversation processor
-    processor_config = ProcessorConfig()
-    offline_chat = OfflineChatProcessorConfig(enable_offline_chat=True)
-    processor_config.conversation = ConversationProcessorConfig(
-        offline_chat=offline_chat,
-        conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
-    )
-
-    return processor_config
-
-
-@pytest.fixture(scope="session")
-def chat_client(md_content_config: ContentConfig, search_config: SearchConfig, processor_config: ProcessorConfig):
+@pytest.fixture(scope="function")
+def chat_client(search_config: SearchConfig, default_user2: KhojUser):
     # Initialize app state
     state.config.search_type = search_config
     state.SearchType = configure_search_types(state.config)

+    LocalMarkdownConfig.objects.create(
+        input_files=None,
+        input_filter=["tests/data/markdown/*.markdown"],
+        user=default_user2,
+    )
+
     # Index Markdown Content for Search
-    all_files = fs_syncer.collect_files()
+    all_files = fs_syncer.collect_files(user=default_user2)
     state.content_index = configure_content(
-        state.content_index, state.config.content_type, all_files, state.search_models
+        state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
     )

     # Initialize Processor from Config
-    state.processor_config = configure_processor(processor_config)
+    if os.getenv("OPENAI_API_KEY"):
+        OpenAIProcessorConversationConfigFactory(user=default_user2)
+
+    state.anonymous_mode = True
+
+    app = FastAPI()
+
+    configure_routes(app)
+    configure_middleware(app)
+    app.mount("/static", StaticFiles(directory=web_directory), name="static")
+    return TestClient(app)
+
+
+@pytest.fixture(scope="function")
+def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUser):
+    # Initialize app state
+    state.config.search_type = search_config
+    state.SearchType = configure_search_types(state.config)
+
+    # Initialize Processor from Config
+    if os.getenv("OPENAI_API_KEY"):
+        OpenAIProcessorConversationConfigFactory(user=default_user2)

     state.anonymous_mode = True

     app = FastAPI()

@@ -249,7 +219,6 @@ def fastapi_app():
 def client(
     content_config: ContentConfig,
     search_config: SearchConfig,
-    processor_config: ProcessorConfig,
     default_user: KhojUser,
 ):
     state.config.content_type = content_config

@@ -274,7 +243,7 @@ def client(
         user=default_user,
     )

-    state.processor_config = configure_processor(processor_config)
+    ConversationProcessorConfigFactory(user=default_user)
     state.anonymous_mode = True

     configure_routes(app)

@@ -286,25 +255,32 @@ def client(
 @pytest.fixture(scope="function")
 def client_offline_chat(
     search_config: SearchConfig,
-    processor_config_offline_chat: ProcessorConfig,
     content_config: ContentConfig,
-    md_content_config,
+    default_user2: KhojUser,
 ):
     # Initialize app state
     state.config.content_type = md_content_config
     state.config.search_type = search_config
     state.SearchType = configure_search_types(state.config)

+    LocalMarkdownConfig.objects.create(
+        input_files=None,
+        input_filter=["tests/data/markdown/*.markdown"],
+        user=default_user2,
+    )
+
     # Index Markdown Content for Search
     state.search_models.image_search = image_search.initialize_model(search_config.image)

-    all_files = fs_syncer.collect_files(state.config.content_type)
+    all_files = fs_syncer.collect_files(user=default_user2)
-    state.content_index = configure_content(
+    configure_content(
-        state.content_index, state.config.content_type, all_files, state.search_models
+        state.content_index, state.config.content_type, all_files, state.search_models, user=default_user2
     )

     # Initialize Processor from Config
-    state.processor_config = configure_processor(processor_config_offline_chat)
+    ConversationProcessorConfigFactory(user=default_user2)
+    OfflineChatProcessorConversationConfigFactory(user=default_user2)

     state.anonymous_mode = True

     configure_routes(app)
tests/helpers.py (new file, 51 lines)

@@ -0,0 +1,51 @@
+import factory
+import os
+
+from database.models import (
+    KhojUser,
+    ConversationProcessorConfig,
+    OfflineChatProcessorConversationConfig,
+    OpenAIProcessorConversationConfig,
+    Conversation,
+)
+
+
+class UserFactory(factory.django.DjangoModelFactory):
+    class Meta:
+        model = KhojUser
+
+    username = factory.Faker("name")
+    email = factory.Faker("email")
+    password = factory.Faker("password")
+    uuid = factory.Faker("uuid4")
+
+
+class ConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
+    class Meta:
+        model = ConversationProcessorConfig
+
+    max_prompt_size = 2000
+    tokenizer = None
+
+
+class OfflineChatProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
+    class Meta:
+        model = OfflineChatProcessorConversationConfig
+
+    enable_offline_chat = True
+    chat_model = "llama-2-7b-chat.ggmlv3.q4_0.bin"
+
+
+class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
+    class Meta:
+        model = OpenAIProcessorConversationConfig
+
+    api_key = os.getenv("OPENAI_API_KEY")
+    chat_model = "gpt-3.5-turbo"
+
+
+class ConversationFactory(factory.django.DjangoModelFactory):
+    class Meta:
+        model = Conversation
+
+    user = factory.SubFactory(UserFactory)
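Note: these factory_boy factories give tests a one-liner way to mint per-user rows instead of mutating global processor state. A hedged example of how a test might use them (the `default_user2` fixture follows the conftest changes above; the assertion is illustrative):

    import pytest
    from tests.helpers import ConversationFactory, OpenAIProcessorConversationConfigFactory

    @pytest.mark.django_db(transaction=True)
    def test_user_gets_isolated_conversation(default_user2):
        # Each factory call persists a row tied to the given user.
        OpenAIProcessorConversationConfigFactory(user=default_user2)
        conversation = ConversationFactory(user=default_user2)

        assert conversation.user == default_user2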
@@ -119,7 +119,12 @@ def test_get_configured_types_via_api(client, sample_org_data):
 def test_get_api_config_types(client, search_config: SearchConfig, sample_org_data, default_user2: KhojUser):
     # Arrange
     text_search.setup(OrgToJsonl, sample_org_data, regenerate=False, user=default_user2)

+    # Act
     response = client.get(f"/api/config/types")

+    # Assert
+    assert response.status_code == 200
     assert response.json() == ["all", "org", "image"]
@@ -9,8 +9,7 @@ from faker import Faker
 # Internal Packages
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.utils import message_to_log
-from khoj.utils import state
+from tests.helpers import ConversationFactory


 SKIP_TESTS = True
 pytestmark = pytest.mark.skipif(

@@ -23,7 +22,7 @@ fake = Faker()

 # Helpers
 # ----------------------------------------------------------------------------------------------------
-def populate_chat_history(message_list):
+def populate_chat_history(message_list, user):
     # Generate conversation logs
     conversation_log = {"chat": []}
     for user_message, llm_message, context in message_list:

@@ -33,14 +32,15 @@ def populate_chat_history(message_list):
             {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
         )

-    # Update Conversation Metadata Logs in Application State
-    state.processor_config.conversation.meta_log = conversation_log
+    # Update Conversation Metadata Logs in Database
+    ConversationFactory(user=user, conversation_log=conversation_log)
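Note: `populate_chat_history` now persists the generated log through `ConversationFactory` instead of stuffing it into application state, so each test's history is scoped to its user row. For orientation, the persisted `conversation_log` is a JSON document along these lines (shape follows `message_to_log` usage in this diff; values are illustrative):

    conversation_log = {
        "chat": [
            # one entry per user message, one per khoj response
            {"message": "When was I born?", "created": "2023-10-26 12:00:00"},
            {
                "message": "You were born on 1st April 1984.",
                "context": [],
                "intent": {"query": "When was I born?", "inferred-queries": '["When was I born?"]'},
            },
        ]
    }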
|
|
||||||
# Tests
|
# Tests
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||||
@pytest.mark.chatquality
|
@pytest.mark.chatquality
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_chat):
|
def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_chat):
|
||||||
# Act
|
# Act
|
||||||
response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
|
response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
|
||||||
|
@ -56,13 +56,14 @@ def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_c
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.chatquality
|
@pytest.mark.chatquality
|
||||||
def test_answer_from_chat_history(client_offline_chat):
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
def test_answer_from_chat_history(client_offline_chat, default_user2):
|
||||||
# Arrange
|
# Arrange
|
||||||
message_list = [
|
message_list = [
|
||||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||||
("When was I born?", "You were born on 1st April 1984.", []),
|
("When was I born?", "You were born on 1st April 1984.", []),
|
||||||
]
|
]
|
||||||
populate_chat_history(message_list)
|
populate_chat_history(message_list, default_user2)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
|
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||||
|
@ -78,7 +79,8 @@ def test_answer_from_chat_history(client_offline_chat):
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.chatquality
|
@pytest.mark.chatquality
|
||||||
def test_answer_from_currently_retrieved_content(client_offline_chat):
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
def test_answer_from_currently_retrieved_content(client_offline_chat, default_user2):
|
||||||
# Arrange
|
# Arrange
|
||||||
message_list = [
|
message_list = [
|
||||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||||
|
@ -88,7 +90,7 @@ def test_answer_from_currently_retrieved_content(client_offline_chat):
|
||||||
["Testatron was born on 1st April 1984 in Testville."],
|
["Testatron was born on 1st April 1984 in Testville."],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
populate_chat_history(message_list)
|
populate_chat_history(message_list, default_user2)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client_offline_chat.get(f'/api/chat?q="Where was Xi Li born?"')
|
response = client_offline_chat.get(f'/api/chat?q="Where was Xi Li born?"')
|
||||||
|
@ -101,7 +103,8 @@ def test_answer_from_currently_retrieved_content(client_offline_chat):
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.chatquality
|
@pytest.mark.chatquality
|
||||||
def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat):
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
def test_answer_from_chat_history_and_previously_retrieved_content(client_offline_chat, default_user2):
|
||||||
# Arrange
|
# Arrange
|
||||||
message_list = [
|
message_list = [
|
||||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||||
|
@ -111,7 +114,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin
|
||||||
["Testatron was born on 1st April 1984 in Testville."],
|
["Testatron was born on 1st April 1984 in Testville."],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
populate_chat_history(message_list)
|
populate_chat_history(message_list, default_user2)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
|
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
|
||||||
|
@ -130,13 +133,14 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin
|
||||||
reason="Chat director not capable of answering this question yet because it requires extract_questions",
|
reason="Chat director not capable of answering this question yet because it requires extract_questions",
|
||||||
)
|
)
|
||||||
@pytest.mark.chatquality
|
@pytest.mark.chatquality
|
||||||
def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat):
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
def test_answer_from_chat_history_and_currently_retrieved_content(client_offline_chat, default_user2):
|
||||||
# Arrange
|
# Arrange
|
||||||
message_list = [
|
message_list = [
|
||||||
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||||
("When was I born?", "You were born on 1st April 1984.", []),
|
("When was I born?", "You were born on 1st April 1984.", []),
|
||||||
]
|
]
|
||||||
populate_chat_history(message_list)
|
populate_chat_history(message_list, default_user2)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
|
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
|
||||||
|
@ -154,14 +158,15 @@ def test_answer_from_chat_history_and_currently_retrieved_content(client_offline
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
|
||||||
@pytest.mark.chatquality
|
@pytest.mark.chatquality
|
||||||
def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat):
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
|
||||||
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
|
"Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
|
||||||
# Arrange
|
# Arrange
|
||||||
message_list = [
|
message_list = [
|
||||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||||
("When was I born?", "You were born on 1st April 1984.", []),
|
("When was I born?", "You were born on 1st April 1984.", []),
|
||||||
]
|
]
|
||||||
populate_chat_history(message_list)
|
populate_chat_history(message_list, default_user2)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"&stream=true')
|
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"&stream=true')
|
||||||
|
@@ -177,11 +182,12 @@ def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat):


 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.chatquality
-def test_answer_using_general_command(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_using_general_command(client_offline_chat, default_user2):
     # Arrange
     query = urllib.parse.quote("/general Where was Xi Li born?")
     message_list = []
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
     response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")

@@ -194,11 +200,12 @@ def test_answer_using_general_command(client_offline_chat):


 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.chatquality
-def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_from_retrieved_content_using_notes_command(client_offline_chat, default_user2):
     # Arrange
     query = urllib.parse.quote("/notes Where was Xi Li born?")
     message_list = []
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
     response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")

@@ -211,12 +218,13 @@ def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):


 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.chatquality
-def test_answer_using_file_filter(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_using_file_filter(client_offline_chat, default_user2):
     # Arrange
     no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"')
     answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
     message_list = []
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
     no_answer_response = client_offline_chat.get(f"/api/chat?q={no_answer_query}&stream=true").content.decode("utf-8")

@@ -229,11 +237,12 @@ def test_answer_using_file_filter(client_offline_chat):


 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.chatquality
-def test_answer_not_known_using_notes_command(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_not_known_using_notes_command(client_offline_chat, default_user2):
     # Arrange
     query = urllib.parse.quote("/notes Where was Testatron born?")
     message_list = []
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
     response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")

@@ -247,6 +256,7 @@ def test_answer_not_known_using_notes_command(client_offline_chat):

 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
 @pytest.mark.chatquality
+@pytest.mark.django_db(transaction=True)
 @freeze_time("2023-04-01")
 def test_answer_requires_current_date_awareness(client_offline_chat):
     "Chat actor should be able to answer questions relative to current date using provided notes"

@@ -265,6 +275,7 @@ def test_answer_requires_current_date_awareness(client_offline_chat):

 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
 @pytest.mark.chatquality
+@pytest.mark.django_db(transaction=True)
 @freeze_time("2023-04-01")
 def test_answer_requires_date_aware_aggregation_across_provided_notes(client_offline_chat):
     "Chat director should be able to answer questions that require date aware aggregation across multiple notes"

@@ -280,14 +291,15 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(client_off

 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
 @pytest.mark.chatquality
-def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_general_question_not_in_chat_history_or_retrieved_content(client_offline_chat, default_user2):
     # Arrange
     message_list = [
         ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
         ("When was I born?", "You were born on 1st April 1984.", []),
         ("Where was I born?", "You were born Testville.", []),
     ]
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
     response = client_offline_chat.get(

@@ -307,7 +319,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(client

 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.xfail(reason="Chat director not consistently capable of asking for clarification yet.")
 @pytest.mark.chatquality
-def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_chat, default_user2):
     # Act
     response = client_offline_chat.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
     response_message = response.content.decode("utf-8")

@@ -328,14 +341,15 @@ def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_

 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
 @pytest.mark.chatquality
-def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2):
     # Arrange
     message_list = [
         ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
         ("When was I born?", "You were born on 1st April 1984.", []),
         ("Where was I born?", "You were born Testville.", []),
     ]
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
     response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')

@@ -350,11 +364,12 @@ def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat):


 @pytest.mark.chatquality
-def test_answer_chat_history_very_long(client_offline_chat):
+@pytest.mark.django_db(transaction=True)
+def test_answer_chat_history_very_long(client_offline_chat, default_user2):
     # Arrange
     message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)]

-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
     response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')

@@ -368,6 +383,7 @@ def test_answer_chat_history_very_long(client_offline_chat):

 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
 @pytest.mark.chatquality
+@pytest.mark.django_db(transaction=True)
 def test_answer_requires_multiple_independent_searches(client_offline_chat):
     "Chat director should be able to answer by doing multiple independent searches for required information"
     # Act
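Every test above that now touches the database carries `@pytest.mark.django_db(transaction=True)`. With pytest-django, the plain `django_db` marker wraps a test in a transaction that is rolled back at the end; `transaction=True` instead runs the test like a `TransactionTestCase`, committing writes and flushing tables between tests, which matters once the code under test talks to the database over its own connections (for example the async ORM calls such as `acreate`). A minimal illustration with a hypothetical test:

    import pytest
    from database.models import KhojUser

    @pytest.mark.django_db(transaction=True)
    def test_committed_rows_visible_everywhere():
        # The row is really committed, so other connections opened by the
        # application under test can see it too.
        KhojUser.objects.create(username="u@example.com", email="u@example.com")
        assert KhojUser.objects.filter(username="u@example.com").exists()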
@@ -9,8 +9,8 @@ from khoj.processor.conversation import prompts

 # Internal Packages
 from khoj.processor.conversation.utils import message_to_log
-from khoj.utils import state
+from tests.helpers import ConversationFactory
+from database.models import KhojUser

 # Initialize variables for tests
 api_key = os.getenv("OPENAI_API_KEY")

@@ -23,7 +23,7 @@ if api_key is None:

 # Helpers
 # ----------------------------------------------------------------------------------------------------
-def populate_chat_history(message_list):
+def populate_chat_history(message_list, user=None):
     # Generate conversation logs
     conversation_log = {"chat": []}
     for user_message, gpt_message, context in message_list:

@@ -33,13 +33,14 @@ def populate_chat_history(message_list):
         {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
     )

-    # Update Conversation Metadata Logs in Application State
-    state.processor_config.conversation.meta_log = conversation_log
+    # Update Conversation Metadata Logs in Database
+    ConversationFactory(user=user, conversation_log=conversation_log)


 # Tests
 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.chatquality
+@pytest.mark.django_db(transaction=True)
 def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
     # Act
     response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
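Pieced together from the hunks above, the `populate_chat_history` helper no longer mutates `state.processor_config.conversation.meta_log`; it persists the generated log through `ConversationFactory`. A sketch of the full helper after this change (the exact `message_to_log` call is inferred from the visible fragments, so treat it as an approximation):

    def populate_chat_history(message_list, user=None):
        # Generate conversation logs
        conversation_log = {"chat": []}
        for user_message, gpt_message, context in message_list:
            conversation_log["chat"] += message_to_log(
                user_message,
                gpt_message,
                {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
            )

        # Update Conversation Metadata Logs in Database
        ConversationFactory(user=user, conversation_log=conversation_log)

`ConversationFactory` presumably wraps the `Conversation` model in a factory_boy `DjangoModelFactory`, so calling it creates and saves one `Conversation` row per test user in a single step.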
@@ -54,14 +55,15 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client):


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
 @pytest.mark.chatquality
-def test_answer_from_chat_history(chat_client):
+def test_answer_from_chat_history(chat_client, default_user2: KhojUser):
     # Arrange
     message_list = [
         ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
         ("When was I born?", "You were born on 1st April 1984.", []),
     ]
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
     response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')

@@ -76,8 +78,9 @@ def test_answer_from_chat_history(chat_client):


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
 @pytest.mark.chatquality
-def test_answer_from_currently_retrieved_content(chat_client):
+def test_answer_from_currently_retrieved_content(chat_client, default_user2: KhojUser):
     # Arrange
     message_list = [
         ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),

@@ -87,7 +90,7 @@ def test_answer_from_currently_retrieved_content(chat_client):
             ["Testatron was born on 1st April 1984 in Testville."],
         ),
     ]
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
     response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')

@@ -99,8 +102,9 @@ def test_answer_from_currently_retrieved_content(chat_client):


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
 @pytest.mark.chatquality
-def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
+def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_no_background, default_user2: KhojUser):
     # Arrange
     message_list = [
         ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),

@@ -110,10 +114,10 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
             ["Testatron was born on 1st April 1984 in Testville."],
         ),
     ]
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
-    response = chat_client.get(f'/api/chat?q="Where was I born?"')
+    response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"')
     response_message = response.content.decode("utf-8")

     # Assert

@@ -125,14 +129,15 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):

 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
+@pytest.mark.django_db(transaction=True)
 @pytest.mark.chatquality
-def test_answer_from_chat_history_and_currently_retrieved_content(chat_client):
+def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, default_user2: KhojUser):
     # Arrange
     message_list = [
         ("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
         ("When was I born?", "You were born on 1st April 1984.", []),
     ]
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
     response = chat_client.get(f'/api/chat?q="Where was I born?"')

@@ -148,15 +153,16 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client):


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
 @pytest.mark.chatquality
-def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
+def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_user2: KhojUser):
     "Chat director should say don't know as not enough contexts in chat history or retrieved to answer question"
     # Arrange
     message_list = [
         ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
         ("When was I born?", "You were born on 1st April 1984.", []),
     ]
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
     response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true')

@@ -171,12 +177,13 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
 @pytest.mark.chatquality
-def test_answer_using_general_command(chat_client):
+def test_answer_using_general_command(chat_client, default_user2: KhojUser):
     # Arrange
     query = urllib.parse.quote("/general Where was Xi Li born?")
     message_list = []
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
     response = chat_client.get(f"/api/chat?q={query}&stream=true")
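As in the offline suite, slash commands like `/general` and `/notes` are percent-encoded before being placed in the `q` query parameter. `urllib.parse.quote` leaves `/` intact by default while escaping spaces and `?`, so the encoded command stays recognizable:

    import urllib.parse

    query = urllib.parse.quote("/general Where was Xi Li born?")
    # query == "/general%20Where%20was%20Xi%20Li%20born%3F"
    url = f"/api/chat?q={query}&stream=true"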
@@ -188,12 +195,13 @@ def test_answer_using_general_command(chat_client):


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
 @pytest.mark.chatquality
-def test_answer_from_retrieved_content_using_notes_command(chat_client):
+def test_answer_from_retrieved_content_using_notes_command(chat_client, default_user2: KhojUser):
     # Arrange
     query = urllib.parse.quote("/notes Where was Xi Li born?")
     message_list = []
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
     response = chat_client.get(f"/api/chat?q={query}&stream=true")

@@ -205,15 +213,16 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client):


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
 @pytest.mark.chatquality
-def test_answer_not_known_using_notes_command(chat_client):
+def test_answer_not_known_using_notes_command(chat_client_no_background, default_user2: KhojUser):
     # Arrange
     query = urllib.parse.quote("/notes Where was Testatron born?")
     message_list = []
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
-    response = chat_client.get(f"/api/chat?q={query}&stream=true")
+    response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true")
     response_message = response.content.decode("utf-8")

     # Assert
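The test above switches to a `chat_client_no_background` fixture, presumably a client whose index holds no background notes, so the director must admit it does not know rather than stumble onto an answer via retrieval. A hypothetical shape for such a fixture (the builder helper and its parameter are assumptions, not the repository's actual conftest):

    import pytest

    @pytest.fixture
    def chat_client_no_background(build_chat_client):
        # Same app and chat configuration as chat_client, but with no
        # content indexed, so retrieval returns nothing for test queries.
        return build_chat_client(index_content=False)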
@@ -223,6 +232,7 @@ def test_answer_not_known_using_notes_command(chat_client):

 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering time aware questions yet")
+@pytest.mark.django_db(transaction=True)
 @pytest.mark.chatquality
 @freeze_time("2023-04-01")
 def test_answer_requires_current_date_awareness(chat_client):

@@ -240,11 +250,13 @@ def test_answer_requires_current_date_awareness(chat_client):


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
 @pytest.mark.chatquality
 @freeze_time("2023-04-01")
 def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client):
     "Chat director should be able to answer questions that require date aware aggregation across multiple notes"
     # Act

     response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true')
     response_message = response.content.decode("utf-8")

@@ -254,15 +266,16 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_clien


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
 @pytest.mark.chatquality
-def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_client):
+def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_client, default_user2: KhojUser):
     # Arrange
     message_list = [
         ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
         ("When was I born?", "You were born on 1st April 1984.", []),
         ("Where was I born?", "You were born Testville.", []),
     ]
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
     response = chat_client.get(

@@ -280,10 +293,12 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
 @pytest.mark.chatquality
-def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
+def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background):
     # Act
-    response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
+
+    response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
     response_message = response.content.decode("utf-8")

     # Assert

@@ -301,15 +316,16 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):

 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
+@pytest.mark.django_db(transaction=True)
 @pytest.mark.chatquality
-def test_answer_in_chat_history_beyond_lookback_window(chat_client):
+def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user2: KhojUser):
     # Arrange
     message_list = [
         ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
         ("When was I born?", "You were born on 1st April 1984.", []),
         ("Where was I born?", "You were born Testville.", []),
     ]
-    populate_chat_history(message_list)
+    populate_chat_history(message_list, default_user2)

     # Act
     response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')

@@ -324,6 +340,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
 @pytest.mark.chatquality
 def test_answer_requires_multiple_independent_searches(chat_client):
     "Chat director should be able to answer by doing multiple independent searches for required information"

@@ -340,10 +357,12 @@ def test_answer_requires_multiple_independent_searches(chat_client):


 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
 def test_answer_using_file_filter(chat_client):
     "Chat should be able to use search filters in the query"
     # Act
     query = urllib.parse.quote('Is Xi older than Namita? file:"Namita.markdown" file:"Xi Li.markdown"')

     response = chat_client.get(f"/api/chat?q={query}&stream=true")
     response_message = response.content.decode("utf-8")
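The last test exercises khoj's `file:` search filter, which scopes retrieval to the named files; two filters are supplied so the director can compare notes drawn from both. Building and encoding such a query round-trips cleanly:

    import urllib.parse

    raw = 'Is Xi older than Namita? file:"Namita.markdown" file:"Xi Li.markdown"'
    query = urllib.parse.quote(raw)
    url = f"/api/chat?q={query}&stream=true"
    assert urllib.parse.unquote(query) == raw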
@@ -13,12 +13,11 @@ from khoj.search_type import text_search
 from khoj.utils.rawconfig import ContentConfig, SearchConfig
 from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
 from khoj.processor.github.github_to_jsonl import GithubToJsonl
-from khoj.utils.config import SearchModels
-from khoj.utils.fs_syncer import get_org_files, collect_files
+from khoj.utils.fs_syncer import collect_files, get_org_files
 from database.models import LocalOrgConfig, KhojUser, Embeddings, GithubConfig

 logger = logging.getLogger(__name__)
-from khoj.utils.rawconfig import ContentConfig, SearchConfig, TextContentConfig
+from khoj.utils.rawconfig import ContentConfig, SearchConfig


 # Test