diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index 128488ac..909c304f 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -1,5 +1,6 @@
# Standard Packages
import logging
+import json
from datetime import datetime, timedelta
from typing import Optional
@@ -31,6 +32,10 @@ def extract_questions(
"""
Infer search queries to retrieve relevant notes to answer user query
"""
+
+ def _valid_question(question: str):
+ return not is_none_or_empty(question) and question != "[]"
+
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join(
[
@@ -70,7 +75,7 @@ def extract_questions(
# Extract, Clean Message from GPT's Response
try:
- questions = (
+ split_questions = (
response.content.strip(empty_escape_sequences)
.replace("['", '["')
.replace("']", '"]')
@@ -79,9 +84,18 @@ def extract_questions(
.replace('"]', "")
.split('", "')
)
+ questions = []
+
+ for question in split_questions:
+ if question not in questions and _valid_question(question):
+ questions.append(question)
+
+ if is_none_or_empty(questions):
+ raise ValueError("GPT returned empty JSON")
except:
logger.warning(f"GPT returned invalid JSON. Falling back to using user message as search query.\n{response}")
questions = [text]
+
logger.debug(f"Extracted Questions by GPT: {questions}")
return questions
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index b44195fd..e2164719 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -154,17 +154,20 @@ def truncate_messages(
)
system_message = messages.pop()
+    assert isinstance(system_message.content, str)
system_message_tokens = len(encoder.encode(system_message.content))
- tokens = sum([len(encoder.encode(message.content)) for message in messages])
+    tokens = sum([len(encoder.encode(message.content)) for message in messages if isinstance(message.content, str)])
while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
messages.pop()
- tokens = sum([len(encoder.encode(message.content)) for message in messages])
+        assert isinstance(system_message.content, str)
+        tokens = sum([len(encoder.encode(message.content)) for message in messages if isinstance(message.content, str)])
# Truncate current message if still over max supported prompt size by model
if (tokens + system_message_tokens) > max_prompt_size:
- current_message = "\n".join(messages[0].content.split("\n")[:-1])
- original_question = "\n".join(messages[0].content.split("\n")[-1:])
+        assert isinstance(system_message.content, str)
+        current_message = "\n".join(messages[0].content.split("\n")[:-1]) if isinstance(messages[0].content, str) else ""
+        original_question = "\n".join(messages[0].content.split("\n")[-1:]) if isinstance(messages[0].content, str) else ""
original_question_tokens = len(encoder.encode(original_question))
remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens
truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index ae13ed70..d9b80756 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -31,6 +31,7 @@ from khoj.utils import state, constants
from khoj.utils.helpers import AsyncIteratorWrapper, get_device
from fastapi.responses import StreamingResponse, Response
from khoj.routers.helpers import (
+ CommonQueryParams,
get_conversation_command,
validate_conversation_config,
agenerate_chat_response,
@@ -55,6 +56,7 @@ from database.models import (
Entry as DbEntry,
GithubConfig,
NotionConfig,
+ ChatModelOptions,
)
@@ -122,7 +124,7 @@ async def map_config_to_db(config: FullConfig, user: KhojUser):
def _initialize_config():
if state.config is None:
state.config = FullConfig()
- state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
+ state.config.search_type = SearchConfig.model_validate(constants.default_config["search-type"])
@api.get("/config/data", response_model=FullConfig)
@@ -355,15 +357,12 @@ def get_config_types(
async def search(
q: str,
request: Request,
+ common: CommonQueryParams,
n: Optional[int] = 5,
t: Optional[SearchType] = SearchType.All,
r: Optional[bool] = False,
max_distance: Optional[Union[float, None]] = None,
dedupe: Optional[bool] = True,
- client: Optional[str] = None,
- user_agent: Optional[str] = Header(None),
- referer: Optional[str] = Header(None),
- host: Optional[str] = Header(None),
):
user = request.user.object
start_time = time.time()
@@ -467,10 +466,7 @@ async def search(
request=request,
telemetry_type="api",
api="search",
- client=client,
- user_agent=user_agent,
- referer=referer,
- host=host,
+ **common.__dict__,
)
end_time = time.time()
@@ -483,12 +479,9 @@ async def search(
@requires(["authenticated"])
def update(
request: Request,
+ common: CommonQueryParams,
t: Optional[SearchType] = None,
force: Optional[bool] = False,
- client: Optional[str] = None,
- user_agent: Optional[str] = Header(None),
- referer: Optional[str] = Header(None),
- host: Optional[str] = Header(None),
):
user = request.user.object
if not state.config:
@@ -514,10 +507,7 @@ def update(
request=request,
telemetry_type="api",
api="update",
- client=client,
- user_agent=user_agent,
- referer=referer,
- host=host,
+ **common.__dict__,
)
return {"status": "ok", "message": "khoj reloaded"}
@@ -527,10 +517,7 @@ def update(
@requires(["authenticated"])
def chat_history(
request: Request,
- client: Optional[str] = None,
- user_agent: Optional[str] = Header(None),
- referer: Optional[str] = Header(None),
- host: Optional[str] = Header(None),
+ common: CommonQueryParams,
):
user = request.user.object
validate_conversation_config()
@@ -542,10 +529,7 @@ def chat_history(
request=request,
telemetry_type="api",
api="chat",
- client=client,
- user_agent=user_agent,
- referer=referer,
- host=host,
+ **common.__dict__,
)
return {"status": "ok", "response": meta_log.get("chat", [])}
@@ -555,10 +539,7 @@ def chat_history(
@requires(["authenticated"])
async def chat_options(
request: Request,
- client: Optional[str] = None,
- user_agent: Optional[str] = Header(None),
- referer: Optional[str] = Header(None),
- host: Optional[str] = Header(None),
+ common: CommonQueryParams,
) -> Response:
cmd_options = {}
for cmd in ConversationCommand:
@@ -568,10 +549,7 @@ async def chat_options(
request=request,
telemetry_type="api",
api="chat_options",
- client=client,
- user_agent=user_agent,
- referer=referer,
- host=host,
+ **common.__dict__,
)
return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200)
@@ -580,14 +558,11 @@ async def chat_options(
@requires(["authenticated"])
async def chat(
request: Request,
+ common: CommonQueryParams,
q: str,
n: Optional[int] = 5,
d: Optional[float] = 0.18,
- client: Optional[str] = None,
stream: Optional[bool] = False,
- user_agent: Optional[str] = Header(None),
- referer: Optional[str] = Header(None),
- host: Optional[str] = Header(None),
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
) -> Response:
@@ -601,7 +576,7 @@ async def chat(
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
- request, meta_log, q, (n or 5), (d or math.inf), conversation_command
+ request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
)
online_results: Dict = dict()
@@ -647,11 +622,8 @@ async def chat(
request=request,
telemetry_type="api",
api="chat",
- client=client,
- user_agent=user_agent,
- referer=referer,
- host=host,
metadata=chat_metadata,
+ **common.__dict__,
)
if llm_response is None:
@@ -678,6 +650,7 @@ async def chat(
async def extract_references_and_questions(
request: Request,
+ common: CommonQueryParams,
meta_log: dict,
q: str,
n: int,
@@ -710,7 +683,16 @@ async def extract_references_and_questions(
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
- if await ConversationAdapters.ahas_offline_chat():
+ offline_chat_config = await ConversationAdapters.aget_offline_chat_conversation_config()
+ conversation_config = await ConversationAdapters.aget_conversation_config(user)
+ if conversation_config is None:
+ conversation_config = await ConversationAdapters.aget_default_conversation_config()
+ openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
+ if (
+ offline_chat_config
+ and offline_chat_config.enabled
+ and conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE
+ ):
using_offline_chat = True
offline_chat = await ConversationAdapters.get_offline_chat()
chat_model = offline_chat.chat_model
@@ -722,7 +704,7 @@ async def extract_references_and_questions(
inferred_queries = extract_questions_offline(
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
)
- elif await ConversationAdapters.has_openai_chat():
+ elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
openai_chat = await ConversationAdapters.get_openai_chat()
api_key = openai_chat_config.api_key
@@ -744,9 +726,9 @@ async def extract_references_and_questions(
r=True,
max_distance=d,
dedupe=False,
+ common=common,
)
)
- # Dedupe the results again, as duplicates may be returned across queries.
result_list = text_search.deduplicated_search_responses(result_list)
compiled_references = [item.additional["compiled"] for item in result_list]
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 207c87a1..b609e977 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -6,10 +6,10 @@ from datetime import datetime
from functools import partial
import logging
from time import time
-from typing import Iterator, List, Optional, Union, Tuple, Dict, Any
+from typing import Annotated, Iterator, List, Optional, Union, Tuple, Dict, Any
# External Packages
-from fastapi import HTTPException, Request
+from fastapi import HTTPException, Header, Request, Depends
# Internal Packages
from khoj.utils import state
@@ -232,3 +232,20 @@ class ApiUserRateLimiter:
# Add the current request to the cache
user_requests.append(time())
+
+
+class CommonQueryParamsClass:
+ def __init__(
+ self,
+ client: Optional[str] = None,
+ user_agent: Optional[str] = Header(None),
+ referer: Optional[str] = Header(None),
+ host: Optional[str] = Header(None),
+ ):
+ self.client = client
+ self.user_agent = user_agent
+ self.referer = referer
+ self.host = host
+
+
+CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py
index a2cb0381..6c4cc9dd 100644
--- a/src/khoj/routers/indexer.py
+++ b/src/khoj/routers/indexer.py
@@ -63,7 +63,7 @@ async def update(
request: Request,
files: list[UploadFile],
force: bool = False,
- t: Optional[Union[state.SearchType, str]] = None,
+ t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
@@ -182,13 +182,16 @@ def configure_content(
files: Optional[dict[str, dict[str, str]]],
search_models: SearchModels,
regenerate: bool = False,
- t: Optional[state.SearchType] = None,
+ t: Optional[state.SearchType] = state.SearchType.All,
full_corpus: bool = True,
user: KhojUser = None,
) -> tuple[Optional[ContentIndex], bool]:
content_index = ContentIndex()
success = True
+ if t is not None and t in [type.value for type in state.SearchType]:
+ t = state.SearchType(t)
+
if t is not None and not t.value in [type.value for type in state.SearchType]:
logger.warning(f"🚨 Invalid search type: {t}")
return None, False
@@ -201,7 +204,7 @@ def configure_content(
try:
# Initialize Org Notes Search
- if (search_type == None or search_type == state.SearchType.Org.value) and files["org"]:
+ if (search_type == state.SearchType.All.value or search_type == state.SearchType.Org.value) and files["org"]:
logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings
text_search.setup(
@@ -217,7 +220,9 @@ def configure_content(
try:
# Initialize Markdown Search
- if (search_type == None or search_type == state.SearchType.Markdown.value) and files["markdown"]:
+ if (search_type == state.SearchType.All.value or search_type == state.SearchType.Markdown.value) and files[
+ "markdown"
+ ]:
logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings
text_search.setup(
@@ -234,7 +239,7 @@ def configure_content(
try:
# Initialize PDF Search
- if (search_type == None or search_type == state.SearchType.Pdf.value) and files["pdf"]:
+ if (search_type == state.SearchType.All.value or search_type == state.SearchType.Pdf.value) and files["pdf"]:
logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings
text_search.setup(
@@ -251,7 +256,9 @@ def configure_content(
try:
# Initialize Plaintext Search
- if (search_type == None or search_type == state.SearchType.Plaintext.value) and files["plaintext"]:
+ if (search_type == state.SearchType.All.value or search_type == state.SearchType.Plaintext.value) and files[
+ "plaintext"
+ ]:
logger.info("📄 Setting up search for plaintext")
# Extract Entries, Generate Plaintext Embeddings
text_search.setup(
@@ -269,7 +276,7 @@ def configure_content(
try:
# Initialize Image Search
if (
- (search_type == None or search_type == state.SearchType.Image.value)
+ (search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value)
and content_config
and content_config.image
and search_models.image_search
@@ -286,7 +293,9 @@ def configure_content(
try:
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
- if (search_type == None or search_type == state.SearchType.Github.value) and github_config is not None:
+ if (
+ search_type == state.SearchType.All.value or search_type == state.SearchType.Github.value
+ ) and github_config is not None:
logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings
text_search.setup(
@@ -305,7 +314,9 @@ def configure_content(
try:
# Initialize Notion Search
notion_config = NotionConfig.objects.filter(user=user).first()
- if (search_type == None or search_type in state.SearchType.Notion.value) and notion_config:
+ if (
+        search_type == state.SearchType.All.value or search_type == state.SearchType.Notion.value
+ ) and notion_config:
logger.info("🔌 Setting up search for notion")
text_search.setup(
NotionToEntries,
diff --git a/src/khoj/search_type/image_search.py b/src/khoj/search_type/image_search.py
index 214118fc..8c0a3cdb 100644
--- a/src/khoj/search_type/image_search.py
+++ b/src/khoj/search_type/image_search.py
@@ -229,7 +229,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
# Add the image metadata to the results
results += [
- SearchResponse.parse_obj(
+ SearchResponse.model_validate(
{
"entry": f"{image_files_url}/{target_image_name}",
"score": f"{hit['score']:.9f}",
@@ -237,7 +237,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
"image_score": f"{hit['image_score']:.9f}",
"metadata_score": f"{hit['metadata_score']:.9f}",
},
- "corpus_id": hit["corpus_id"],
+ "corpus_id": str(hit["corpus_id"]),
}
)
]
diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index f07eb580..7e295903 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -163,7 +163,7 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
else:
hit_ids.add(hit.corpus_id)
- yield SearchResponse.parse_obj(
+ yield SearchResponse.model_validate(
{
"entry": hit.entry,
"score": hit.score,
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index 799184ae..42e3835d 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -288,15 +288,15 @@ def generate_random_name():
# List of adjectives and nouns to choose from
adjectives = [
"happy",
- "irritated",
- "annoyed",
+ "serendipitous",
+ "exuberant",
"calm",
"brave",
"scared",
"energetic",
"chivalrous",
"kind",
- "grumpy",
+ "suave",
]
nouns = ["dog", "cat", "falcon", "whale", "turtle", "rabbit", "hamster", "snake", "spider", "elephant"]
diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py
index 48525ab2..4c97aedd 100644
--- a/src/khoj/utils/rawconfig.py
+++ b/src/khoj/utils/rawconfig.py
@@ -14,7 +14,7 @@ from khoj.utils.helpers import to_snake_case_from_dash
class ConfigBase(BaseModel):
class Config:
alias_generator = to_snake_case_from_dash
- allow_population_by_field_name = True
+ populate_by_name = True
def __getitem__(self, item):
return getattr(self, item)
@@ -29,8 +29,8 @@ class TextConfigBase(ConfigBase):
class TextContentConfig(ConfigBase):
- input_files: Optional[List[Path]]
- input_filter: Optional[List[str]]
+ input_files: Optional[List[Path]] = None
+ input_filter: Optional[List[str]] = None
index_heading_entries: Optional[bool] = False
@@ -50,31 +50,31 @@ class NotionContentConfig(ConfigBase):
class ImageContentConfig(ConfigBase):
- input_directories: Optional[List[Path]]
- input_filter: Optional[List[str]]
+ input_directories: Optional[List[Path]] = None
+ input_filter: Optional[List[str]] = None
embeddings_file: Path
use_xmp_metadata: bool
batch_size: int
class ContentConfig(ConfigBase):
- org: Optional[TextContentConfig]
- image: Optional[ImageContentConfig]
- markdown: Optional[TextContentConfig]
- pdf: Optional[TextContentConfig]
- plaintext: Optional[TextContentConfig]
- github: Optional[GithubContentConfig]
- notion: Optional[NotionContentConfig]
+ org: Optional[TextContentConfig] = None
+ image: Optional[ImageContentConfig] = None
+ markdown: Optional[TextContentConfig] = None
+ pdf: Optional[TextContentConfig] = None
+ plaintext: Optional[TextContentConfig] = None
+ github: Optional[GithubContentConfig] = None
+ notion: Optional[NotionContentConfig] = None
class ImageSearchConfig(ConfigBase):
encoder: str
- encoder_type: Optional[str]
- model_directory: Optional[Path]
+ encoder_type: Optional[str] = None
+ model_directory: Optional[Path] = None
class SearchConfig(ConfigBase):
- image: Optional[ImageSearchConfig]
+ image: Optional[ImageSearchConfig] = None
class OpenAIProcessorConfig(ConfigBase):
@@ -95,26 +95,26 @@ class ConversationProcessorConfig(ConfigBase):
class ProcessorConfig(ConfigBase):
- conversation: Optional[ConversationProcessorConfig]
+ conversation: Optional[ConversationProcessorConfig] = None
class AppConfig(ConfigBase):
- should_log_telemetry: bool
+ should_log_telemetry: bool = True
class FullConfig(ConfigBase):
content_type: Optional[ContentConfig] = None
search_type: Optional[SearchConfig] = None
processor: Optional[ProcessorConfig] = None
- app: Optional[AppConfig] = AppConfig(should_log_telemetry=True)
+ app: Optional[AppConfig] = AppConfig()
version: Optional[str] = None
class SearchResponse(ConfigBase):
entry: str
score: float
- cross_score: Optional[float]
- additional: Optional[dict]
+ cross_score: Optional[float] = None
+ additional: Optional[dict] = None
corpus_id: str
diff --git a/src/khoj/utils/yaml.py b/src/khoj/utils/yaml.py
index abfe270a..36546688 100644
--- a/src/khoj/utils/yaml.py
+++ b/src/khoj/utils/yaml.py
@@ -39,7 +39,7 @@ def load_config_from_file(yaml_config_file: Path) -> dict:
def parse_config_from_string(yaml_config: dict) -> FullConfig:
"Parse and validate config in YML string"
- return FullConfig.parse_obj(yaml_config)
+ return FullConfig.model_validate(yaml_config)
def parse_config_from_file(yaml_config_file):
diff --git a/tests/conftest.py b/tests/conftest.py
index 16f0ef1b..d3a27748 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,9 +9,6 @@ import os
from fastapi import FastAPI
-app = FastAPI()
-
-
# Internal Packages
from khoj.configure import configure_routes, configure_search_types, configure_middleware
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
@@ -320,6 +317,7 @@ def client(
state.anonymous_mode = False
+ app = FastAPI()
configure_routes(app)
configure_middleware(app)
app.mount("/static", StaticFiles(directory=web_directory), name="static")
diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py
index a8c85787..07c4e0d8 100644
--- a/tests/test_openai_chat_director.py
+++ b/tests/test_openai_chat_director.py
@@ -227,7 +227,7 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default
# Assert
assert response.status_code == 200
- assert response_message == prompts.no_notes_found.format()
+ assert response_message == prompts.no_entries_found.format()
# ----------------------------------------------------------------------------------------------------
diff --git a/versions.json b/versions.json
index 8deb4367..06efeecb 100644
--- a/versions.json
+++ b/versions.json
@@ -26,5 +26,6 @@
"0.12.2": "0.15.0",
"0.12.3": "0.15.0",
"0.13.0": "0.15.0",
- "0.14.0": "0.15.0"
+ "0.14.0": "0.15.0",
+ "1.0.0": "0.15.0"
}