diff --git a/gunicorn-config.py b/gunicorn-config.py
index bfed49e7..ea382346 100644
--- a/gunicorn-config.py
+++ b/gunicorn-config.py
@@ -1,10 +1,10 @@
import multiprocessing
bind = "0.0.0.0:42110"
-workers = 8
+workers = 1
worker_class = "uvicorn.workers.UvicornWorker"
timeout = 120
keep_alive = 60
-accesslog = "access.log"
-errorlog = "error.log"
+accesslog = "-"
+errorlog = "-"
loglevel = "debug"
diff --git a/prod.Dockerfile b/prod.Dockerfile
index 413835d0..0da5363a 100644
--- a/prod.Dockerfile
+++ b/prod.Dockerfile
@@ -1,12 +1,9 @@
-# Use Nvidia's latest Ubuntu 22.04 image as the base image
-FROM nvidia/cuda:12.2.0-devel-ubuntu22.04
+FROM ubuntu:jammy
LABEL org.opencontainers.image.source https://github.com/khoj-ai/khoj
# Install System Dependencies
RUN apt update -y && apt -y install python3-pip libsqlite3-0 ffmpeg libsm6 libxext6
-# Install Optional Dependencies
-RUN apt install vim -y
WORKDIR /app
diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html
index 94cde782..f37ae562 100644
--- a/src/interface/desktop/chat.html
+++ b/src/interface/desktop/chat.html
@@ -87,7 +87,7 @@
function generateOnlineReference(reference, index) {
// Generate HTML for Chat Reference
- let title = reference.title;
+ let title = reference.title || reference.link;
let link = reference.link;
let snippet = reference.snippet;
let question = reference.question;
@@ -191,6 +191,15 @@
referenceSection.appendChild(polishedReference);
}
}
+
+ if (onlineReference.webpages && onlineReference.webpages.length > 0) {
+ numOnlineReferences += onlineReference.webpages.length;
+ for (let index in onlineReference.webpages) {
+ let reference = onlineReference.webpages[index];
+ let polishedReference = generateOnlineReference(reference, index);
+ referenceSection.appendChild(polishedReference);
+ }
+ }
}
return numOnlineReferences;
diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index fb3a93ce..0adbe889 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -268,6 +268,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
def configure_routes(app):
# Import APIs here to setup search types before while configuring server
from khoj.routers.api import api
+ from khoj.routers.api_agents import api_agents
from khoj.routers.api_chat import api_chat
from khoj.routers.api_config import api_config
from khoj.routers.indexer import indexer
@@ -275,6 +276,7 @@ def configure_routes(app):
app.include_router(api, prefix="/api")
app.include_router(api_chat, prefix="/api/chat")
+ app.include_router(api_agents, prefix="/api/agents")
app.include_router(api_config, prefix="/api/config")
app.include_router(indexer, prefix="/api/v1/index")
app.include_router(web_client)
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index cb317275..25f781fe 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -394,20 +394,32 @@ class ClientApplicationAdapters:
class AgentAdapters:
- DEFAULT_AGENT_NAME = "khoj"
+ DEFAULT_AGENT_NAME = "Khoj"
DEFAULT_AGENT_AVATAR = "https://khoj-web-bucket.s3.amazonaws.com/lamp-128.png"
+ DEFAULT_AGENT_SLUG = "khoj"
@staticmethod
- async def aget_agent_by_id(agent_id: int, user: KhojUser):
- agent = await Agent.objects.filter(id=agent_id).afirst()
- # Check if it's accessible to the user
- if agent and (agent.public or agent.creator == user):
- return agent
- return None
+ async def aget_agent_by_slug(agent_slug: str, user: KhojUser):
+ return await Agent.objects.filter(
+ (Q(slug__iexact=agent_slug.lower())) & (Q(public=True) | Q(creator=user))
+ ).afirst()
+
+ @staticmethod
+ def get_agent_by_slug(slug: str, user: KhojUser = None):
+ if user:
+ return Agent.objects.filter((Q(slug__iexact=slug.lower())) & (Q(public=True) | Q(creator=user))).first()
+ return Agent.objects.filter(slug__iexact=slug.lower(), public=True).first()
@staticmethod
def get_all_accessible_agents(user: KhojUser = None):
- return Agent.objects.filter(Q(public=True) | Q(creator=user)).distinct()
+ if user:
+ return Agent.objects.filter(Q(public=True) | Q(creator=user)).distinct().order_by("created_at")
+ return Agent.objects.filter(public=True).order_by("created_at")
+
+ @staticmethod
+ async def aget_all_accessible_agents(user: KhojUser = None) -> List[Agent]:
+ agents = await sync_to_async(AgentAdapters.get_all_accessible_agents)(user)
+ return await sync_to_async(list)(agents)
@staticmethod
def get_conversation_agent_by_id(agent_id: int):
@@ -423,12 +435,19 @@ class AgentAdapters:
@staticmethod
def create_default_agent():
- # First delete the existing default
- Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).delete()
-
default_conversation_config = ConversationAdapters.get_default_conversation_config()
default_personality = prompts.personality.format(current_date="placeholder")
+ agent = Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()
+
+ if agent:
+ agent.personality = default_personality
+ agent.chat_model = default_conversation_config
+ agent.slug = AgentAdapters.DEFAULT_AGENT_SLUG
+ agent.name = AgentAdapters.DEFAULT_AGENT_NAME
+ agent.save()
+ return agent
+
# The default agent is public and managed by the admin. It's handled a little differently than other agents.
return Agent.objects.create(
name=AgentAdapters.DEFAULT_AGENT_NAME,
@@ -438,6 +457,7 @@ class AgentAdapters:
personality=default_personality,
tools=["*"],
avatar=AgentAdapters.DEFAULT_AGENT_AVATAR,
+ slug=AgentAdapters.DEFAULT_AGENT_SLUG,
)
@staticmethod
@@ -486,10 +506,12 @@ class ConversationAdapters:
@staticmethod
async def acreate_conversation_session(
- user: KhojUser, client_application: ClientApplication = None, agent_id: int = None
+ user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None
):
- if agent_id:
- agent = await AgentAdapters.aget_agent_by_id(agent_id, user)
+ if agent_slug:
+ agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)
+ if agent is None:
+ raise HTTPException(status_code=400, detail="No such agent currently exists.")
return await Conversation.objects.acreate(user=user, client=client_application, agent=agent)
return await Conversation.objects.acreate(user=user, client=client_application)
diff --git a/src/khoj/database/migrations/0031_agent_conversation_agent.py b/src/khoj/database/migrations/0031_agent_conversation_agent.py
index 16586499..1d08a118 100644
--- a/src/khoj/database/migrations/0031_agent_conversation_agent.py
+++ b/src/khoj/database/migrations/0031_agent_conversation_agent.py
@@ -1,4 +1,4 @@
-# Generated by Django 4.2.10 on 2024-03-11 05:12
+# Generated by Django 4.2.10 on 2024-03-13 07:38
import django.db.models.deletion
from django.conf import settings
@@ -23,6 +23,7 @@ class Migration(migrations.Migration):
("tools", models.JSONField(default=list)),
("public", models.BooleanField(default=False)),
("managed_by_admin", models.BooleanField(default=False)),
+ ("slug", models.CharField(max_length=200)),
(
"chat_model",
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.chatmodeloptions"),
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index b8eeb8b1..364a6d1a 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -1,4 +1,5 @@
import uuid
+from random import choice
from django.contrib.auth.models import AbstractUser
from django.core.exceptions import ValidationError
@@ -94,13 +95,28 @@ class Agent(BaseModel):
public = models.BooleanField(default=False)
managed_by_admin = models.BooleanField(default=False)
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
+ slug = models.CharField(max_length=200)
@receiver(pre_save, sender=Agent)
-def check_public_name(sender, instance, **kwargs):
- if instance.public:
+def verify_agent(sender, instance, **kwargs):
+ # check if this is a new instance
+ if instance._state.adding:
if Agent.objects.filter(name=instance.name, public=True).exists():
raise ValidationError(f"A public Agent with the name {instance.name} already exists.")
+ if Agent.objects.filter(name=instance.name, creator=instance.creator).exists():
+ raise ValidationError(f"A private Agent with the name {instance.name} already exists.")
+
+ slug = instance.name.lower().replace(" ", "-")
+ observed_random_numbers = set()
+ while Agent.objects.filter(slug=slug).exists():
+ try:
+ random_number = choice([i for i in range(0, 1000) if i not in observed_random_numbers])
+ except IndexError:
+ raise ValidationError("Unable to generate a unique slug for the Agent. Please try again later.")
+ observed_random_numbers.add(random_number)
+ slug = f"{slug}-{random_number}"
+ instance.slug = slug
class NotionConfig(BaseModel):
diff --git a/src/khoj/interface/web/404.html b/src/khoj/interface/web/404.html
index 7041ff80..0762bde8 100644
--- a/src/khoj/interface/web/404.html
+++ b/src/khoj/interface/web/404.html
@@ -2,14 +2,19 @@
Khoj: An AI Personal Assistant for your digital brain
-
-
+
+
+
+
+ {% import 'utils.html' as utils %}
+ {{ utils.heading_pane(user_photo, username, is_active, has_documents) }}
+
- Go Home
+ Go Home
@@ -18,5 +23,34 @@
body.not-found {
padding: 0 10%
}
+
+ body {
+ background-color: var(--background-color);
+ color: var(--main-text-color);
+ text-align: center;
+ font-family: var(--font-family);
+ font-size: medium;
+ font-weight: 300;
+ line-height: 1.5em;
+ height: 100vh;
+ margin: 0;
+ }
+
+ body a.redirect-link {
+ font-size: 18px;
+ font-weight: bold;
+ background-color: var(--primary);
+ text-decoration: none;
+ border: 1px solid var(--main-text-color);
+ color: var(--main-text-color);
+ border-radius: 8px;
+ padding: 4px;
+ }
+
+ body a.redirect-link:hover {
+ background-color: var(--main-text-color);
+ color: var(--primary);
+ }
+
diff --git a/src/khoj/interface/web/agent.html b/src/khoj/interface/web/agent.html
new file mode 100644
index 00000000..6e6619cf
--- /dev/null
+++ b/src/khoj/interface/web/agent.html
@@ -0,0 +1,286 @@
+
+
+
+
+ Khoj - Agents
+
+
+
+
+
+
+
+
+ {% import 'utils.html' as utils %}
+ {{ utils.heading_pane(user_photo, username, is_active, has_documents) }}
+
+
+
+
+
+
diff --git a/src/khoj/interface/web/agents.html b/src/khoj/interface/web/agents.html
new file mode 100644
index 00000000..dc26606c
--- /dev/null
+++ b/src/khoj/interface/web/agents.html
@@ -0,0 +1,201 @@
+
+
+
+
+ Khoj - Agents
+
+
+
+
+
+
+
+
+ {% import 'utils.html' as utils %}
+ {{ utils.heading_pane(user_photo, username, is_active, has_documents) }}
+
+
+
+
+
+ {% for agent in agents %}
+
+ {% endfor %}
+
+
+
+
+
+
+
diff --git a/src/khoj/interface/web/assets/khoj.css b/src/khoj/interface/web/assets/khoj.css
index 7ba93c6a..3d7e7d4a 100644
--- a/src/khoj/interface/web/assets/khoj.css
+++ b/src/khoj/interface/web/assets/khoj.css
@@ -130,7 +130,7 @@ img.khoj-logo {
background-color: var(--background-color);
min-width: 160px;
box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
- right: 15vw;
+ right: 5vw;
top: 64px;
z-index: 1;
opacity: 0;
diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html
index 5c26e060..870f4eb9 100644
--- a/src/khoj/interface/web/base_config.html
+++ b/src/khoj/interface/web/base_config.html
@@ -162,7 +162,7 @@
height: 40px;
}
.card-title {
- font-size: 20px;
+ font-size: medium;
font-weight: normal;
margin: 0;
padding: 0;
diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index 35047c31..52750c12 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -5,6 +5,7 @@
Khoj - Chat
+
@@ -12,15 +13,16 @@
@@ -1205,13 +1328,27 @@ To get started, just start typing below. You can also type / to see a list of co
+
+
+
+
* {
@@ -1891,6 +2033,7 @@ To get started, just start typing below. You can also type / to see a list of co
div#new-conversation {
text-align: left;
border-bottom: 1px solid var(--main-text-color);
+ margin-top: 8px;
margin-bottom: 8px;
}
@@ -2046,6 +2189,170 @@ To get started, just start typing below. You can also type / to see a list of co
animation-delay: -0.5s;
}
+ #agent-metadata-content {
+ display: grid;
+ grid-template-columns: auto 1fr;
+ padding: 10px;
+ background-color: var(--primary);
+ border-radius: 5px;
+ box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
+ margin-bottom: 20px;
+ }
+
+ #agent-metadata {
+ border-top: 1px solid black;
+ padding-top: 10px;
+ }
+
+ #agent-avatar-wrapper {
+ margin-right: 10px;
+ }
+
+ #agent-avatar {
+ width: 50px;
+ height: 50px;
+ border-radius: 50%;
+ object-fit: cover;
+ }
+
+ #agent-name-wrapper {
+ display: grid;
+ align-items: center;
+ }
+
+ #agent-name {
+ font-size: 18px;
+ font-weight: bold;
+ color: #333;
+ }
+
+ #agent-instructions {
+ font-size: 14px;
+ color: #666;
+ height: 50px;
+ overflow: auto;
+ }
+
+ #agent-owned-by-user {
+ font-size: 12px;
+ color: #007BFF;
+ margin-top: 5px;
+ }
+
+ .modal {
+ position: fixed; /* Stay in place */
+ z-index: 1; /* Sit on top */
+ left: 0;
+ top: 0;
+ width: 100%; /* Full width */
+ height: 100%; /* Full height */
+ background-color: rgba(0,0,0,0.4); /* Black w/ opacity */
+ margin: 0px;
+ }
+
+ .modal-content {
+ margin: 15% auto; /* 15% from the top and centered */
+ padding: 20px;
+ border: 1px solid #888;
+ width: 250px;
+ text-align: left;
+ background: var(--background-color);
+ border-radius: 5px;
+ box-shadow: 0 0 11px #aaa;
+ text-align: left;
+ }
+
+ .modal-header {
+ display: grid;
+ grid-template-columns: 1fr auto;
+ color: var(--main-text-color);
+ align-items: baseline;
+ }
+
+ .modal-header h2 {
+ margin: 0;
+ text-align: left;
+ }
+
+ .modal-body {
+ display: grid;
+ grid-auto-flow: row;
+ gap: 8px;
+ }
+
+ .modal-body a {
+ /* text-decoration: none; */
+ color: var(--summer-sun);
+ }
+
+ .modal-close-button {
+ margin: 0;
+ font-size: 20px;
+ background: none;
+ border: none;
+ color: var(--summer-sun);
+ }
+
+ .modal-close-button:hover,
+ .modal-close-button:focus {
+ color: #000;
+ text-decoration: none;
+ cursor: pointer;
+ }
+
+ #new-conversation-form {
+ display: flex;
+ flex-direction: column;
+ }
+
+ #new-conversation-form label,
+ #new-conversation-form input,
+ #new-conversation-form button {
+ margin-bottom: 10px;
+ }
+
+ #new-conversation-form button {
+ cursor: pointer;
+ }
+
+ .modal-footer {
+ display: grid;
+ grid-template-columns: 1fr 1fr;
+ grid-gap: 12px;
+ }
+
+ .modal-body button {
+ cursor: pointer;
+ border-radius: 12px;
+ padding: 8px;
+ border: 1px solid var(--main-text-color);
+ }
+
+ button#new-conversation-submit-button {
+ background: var(--summer-sun);
+ transition: background 0.2s ease-in-out;
+ }
+
+ button#close-button {
+ background: var(--background-color);
+ transition: background 0.2s ease-in-out;
+ }
+
+ button#new-conversation-submit-button:hover {
+ background: var(--primary);
+ }
+
+ button#close-button:hover {
+ background: var(--primary-hover);
+ }
+
+ .modal-body select {
+ padding: 8px;
+ border-radius: 12px;
+ border: 1px solid var(--main-text-color);
+ }
+
+
@keyframes lds-ripple {
0% {
top: 36px;
diff --git a/src/khoj/interface/web/search.html b/src/khoj/interface/web/search.html
index 8bbd9f32..100450c7 100644
--- a/src/khoj/interface/web/search.html
+++ b/src/khoj/interface/web/search.html
@@ -5,6 +5,7 @@
Khoj - Search
+
diff --git a/src/khoj/interface/web/utils.html b/src/khoj/interface/web/utils.html
index edd75fc1..aa65be90 100644
--- a/src/khoj/interface/web/utils.html
+++ b/src/khoj/interface/web/utils.html
@@ -9,26 +9,28 @@
🔎 Search
{% endif %}
-
{%- endmacro %}
diff --git a/src/khoj/main.py b/src/khoj/main.py
index 5099002f..a9e333d2 100644
--- a/src/khoj/main.py
+++ b/src/khoj/main.py
@@ -160,7 +160,9 @@ def start_server(app, host=None, port=None, socket=None):
if socket:
uvicorn.run(app, proxy_headers=True, uds=socket, log_level="debug", use_colors=True, log_config=None)
else:
- uvicorn.run(app, host=host, port=port, log_level="debug", use_colors=True, log_config=None)
+ uvicorn.run(
+ app, host=host, port=port, log_level="debug", use_colors=True, log_config=None, timeout_keep_alive=60
+ )
logger.info("🌒 Stopping Khoj")
diff --git a/src/khoj/processor/content/markdown/markdown_to_entries.py b/src/khoj/processor/content/markdown/markdown_to_entries.py
index 4cce4876..7274cf1c 100644
--- a/src/khoj/processor/content/markdown/markdown_to_entries.py
+++ b/src/khoj/processor/content/markdown/markdown_to_entries.py
@@ -114,7 +114,7 @@ class MarkdownToEntries(TextToEntries):
# Append base filename to compiled entry for context to model
# Increment heading level for heading entries and make filename as its top level heading
prefix = f"# {stem}\n#" if heading else f"# {stem}\n"
- compiled_entry = f"{prefix}{parsed_entry}"
+ compiled_entry = f"{entry_filename}\n{prefix}{parsed_entry}"
entries.append(
Entry(
compiled=compiled_entry,
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index 6348bb1a..d4d28b6f 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -188,8 +188,8 @@ def converse_offline(
if ConversationCommand.Online in conversation_commands:
simplified_online_results = online_results.copy()
for result in online_results:
- if online_results[result].get("extracted_content"):
- simplified_online_results[result] = online_results[result]["extracted_content"]
+ if online_results[result].get("webpages"):
+ simplified_online_results[result] = online_results[result]["webpages"]
conversation_primer = f"{prompts.online_search_conversation.format(online_results=str(simplified_online_results))}\n{conversation_primer}"
if not is_none_or_empty(compiled_references_message):
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index a76fd6e9..584037ed 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -1,7 +1,7 @@
import json
import logging
from datetime import datetime, timedelta
-from typing import Optional
+from typing import Dict, Optional
from langchain.schema import ChatMessage
@@ -105,7 +105,7 @@ def send_message_to_model(messages, api_key, model, response_type="text"):
def converse(
references,
user_query,
- online_results: Optional[dict] = None,
+ online_results: Optional[Dict[str, Dict]] = None,
conversation_log={},
model: str = "gpt-3.5-turbo",
api_key: Optional[str] = None,
@@ -151,7 +151,7 @@ def converse(
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
- if ConversationCommand.Online in conversation_commands:
+ if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
conversation_primer = (
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
)
@@ -167,7 +167,7 @@ def converse(
max_prompt_size,
tokenizer_name,
)
- truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
+ truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
logger.debug(f"Conversation Context for GPT: {truncated_messages}")
# Get Response from GPT
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index 3bdb683a..b500c1ec 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -10,7 +10,7 @@ You were created by Khoj Inc. with the following capabilities:
- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you.
- Users can share files and other information with you using the Khoj Desktop, Obsidian or Emacs app. They can also drag and drop their files into the chat window.
-- You can generate images, look-up information from the internet, and answer questions based on the user's notes.
+- You *CAN* generate images, look-up real-time information from the internet, and answer questions based on the user's notes.
- You cannot set reminders.
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
@@ -23,7 +23,7 @@ Today is {current_date} in UTC.
custom_personality = PromptTemplate.from_template(
"""
-Your are {name}, a personal agent on Khoj.
+You are {name}, a personal agent on Khoj.
Use your general knowledge and past conversation with the user as context to inform your responses.
You were created by Khoj Inc. with the following capabilities:
@@ -178,7 +178,8 @@ online_search_conversation = PromptTemplate.from_template(
Use this up-to-date information from the internet to inform your response.
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the online data or past conversations.
-Information from the internet: {online_results}
+Information from the internet:
+{online_results}
""".strip()
)
@@ -312,7 +313,7 @@ Target Query: {query}
Web Pages:
{corpus}
-Collate the relevant information from the website to answer the target query.
+Collate only relevant information from the website to answer the target query.
""".strip()
)
@@ -394,6 +395,14 @@ AI: Good morning! How can I help you today?
Q: How can I share my files with Khoj?
Khoj: {{"source": ["default", "online"]}}
+Example:
+Chat History:
+User: What is the first element in the periodic table?
+AI: The first element in the periodic table is Hydrogen.
+
+Q: Summarize this article https://en.wikipedia.org/wiki/Hydrogen
+Khoj: {{"source": ["webpage"]}}
+
Example:
Chat History:
User: I want to start a new hobby. I'm thinking of learning to play the guitar.
@@ -412,6 +421,50 @@ Khoj:
""".strip()
)
+infer_webpages_to_read = PromptTemplate.from_template(
+ """
+You are Khoj, an advanced web page reading assistant. You are to construct **up to three, valid** webpage urls to read before answering the user's question.
+- You will receive the conversation history as context.
+- Add as much context from the previous questions and answers as required to construct the webpage urls.
+- Use multiple web page urls if required to retrieve the relevant information.
+- You have access to the the whole internet to retrieve information.
+
+Which webpages will you need to read to answer the user's question?
+Provide web page links as a list of strings in a JSON object.
+Current Date: {current_date}
+User's Location: {location}
+
+Here are some examples:
+History:
+User: I like to use Hacker News to get my tech news.
+AI: Hacker News is an online forum for sharing and discussing the latest tech news. It is a great place to learn about new technologies and startups.
+
+Q: Summarize this post about vector database on Hacker News, https://news.ycombinator.com/item?id=12345
+Khoj: {{"links": ["https://news.ycombinator.com/item?id=12345"]}}
+
+History:
+User: I'm currently living in New York but I'm thinking about moving to San Francisco.
+AI: New York is a great city to live in. It has a lot of great restaurants and museums. San Francisco is also a great city to live in. It has good access to nature and a great tech scene.
+
+Q: What is the climate like in those cities?
+Khoj: {{"links": ["https://en.wikipedia.org/wiki/New_York_City", "https://en.wikipedia.org/wiki/San_Francisco"]}}
+
+History:
+User: Hey, how is it going?
+AI: Not too bad. How can I help you today?
+
+Q: What's the latest news on r/worldnews?
+Khoj: {{"links": ["https://www.reddit.com/r/worldnews/"]}}
+
+Now it's your turn to share actual webpage urls you'd like to read to answer the user's question.
+History:
+{chat_history}
+
+Q: {query}
+Khoj:
+""".strip()
+)
+
online_search_conversation_subqueries = PromptTemplate.from_template(
"""
You are Khoj, an advanced google search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py
index 597f394e..c9745dc9 100644
--- a/src/khoj/processor/tools/online_search.py
+++ b/src/khoj/processor/tools/online_search.py
@@ -2,6 +2,7 @@ import asyncio
import json
import logging
import os
+from collections import defaultdict
from typing import Dict, Tuple, Union
import aiohttp
@@ -9,7 +10,11 @@ import requests
from bs4 import BeautifulSoup
from markdownify import markdownify
-from khoj.routers.helpers import extract_relevant_info, generate_online_subqueries
+from khoj.routers.helpers import (
+ extract_relevant_info,
+ generate_online_subqueries,
+ infer_webpage_urls,
+)
from khoj.utils.helpers import is_none_or_empty, timer
from khoj.utils.rawconfig import LocationData
@@ -38,7 +43,7 @@ MAX_WEBPAGES_TO_READ = 1
async def search_online(query: str, conversation_history: dict, location: LocationData):
- if SERPER_DEV_API_KEY is None:
+ if not online_search_enabled():
logger.warn("SERPER_DEV_API_KEY is not set")
return {}
@@ -52,24 +57,21 @@ async def search_online(query: str, conversation_history: dict, location: Locati
# Gather distinct web pages from organic search results of each subquery without an instant answer
webpage_links = {
- result["link"]
+ organic["link"]: subquery
for subquery in response_dict
- for result in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]
+ for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]
if "answerBox" not in response_dict[subquery]
}
# Read, extract relevant info from the retrieved web pages
- tasks = []
- for webpage_link in webpage_links:
- logger.info(f"Reading web page at '{webpage_link}'")
- task = read_webpage_and_extract_content(subquery, webpage_link)
- tasks.append(task)
+ logger.info(f"Reading web pages at: {webpage_links.keys()}")
+ tasks = [read_webpage_and_extract_content(subquery, link) for link, subquery in webpage_links.items()]
results = await asyncio.gather(*tasks)
# Collect extracted info from the retrieved web pages
- for subquery, extracted_webpage_content in results:
- if extracted_webpage_content is not None:
- response_dict[subquery]["extracted_content"] = extracted_webpage_content
+ for subquery, webpage_extract, url in results:
+ if webpage_extract is not None:
+ response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract}
return response_dict
@@ -93,19 +95,35 @@ def search_with_google(subquery: str):
return extracted_search_result
-async def read_webpage_and_extract_content(subquery: str, url: str) -> Tuple[str, Union[None, str]]:
+async def read_webpages(query: str, conversation_history: dict, location: LocationData):
+ "Infer web pages to read from the query and extract relevant information from them"
+ logger.info(f"Inferring web pages to read")
+ urls = await infer_webpage_urls(query, conversation_history, location)
+
+ logger.info(f"Reading web pages at: {urls}")
+ tasks = [read_webpage_and_extract_content(query, url) for url in urls]
+ results = await asyncio.gather(*tasks)
+
+ response: Dict[str, Dict] = defaultdict(dict)
+ response[query]["webpages"] = [
+ {"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None
+ ]
+ return response
+
+
+async def read_webpage_and_extract_content(subquery: str, url: str) -> Tuple[str, Union[None, str], str]:
try:
with timer(f"Reading web page at '{url}' took", logger):
- content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage(url)
+ content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_at_url(url)
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subquery, content)
- return subquery, extracted_info
+ return subquery, extracted_info, url
except Exception as e:
logger.error(f"Failed to read web page at '{url}' with {e}")
- return subquery, None
+ return subquery, None, url
-async def read_webpage(web_url: str) -> str:
+async def read_webpage_at_url(web_url: str) -> str:
headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.97 Safari/537.36",
}
@@ -129,3 +147,7 @@ async def read_webpage_with_olostep(web_url: str) -> str:
response.raise_for_status()
response_json = await response.json()
return response_json["markdown_content"]
+
+
+def online_search_enabled():
+ return SERPER_DEV_API_KEY is not None
diff --git a/src/khoj/routers/api_agents.py b/src/khoj/routers/api_agents.py
new file mode 100644
index 00000000..14b86217
--- /dev/null
+++ b/src/khoj/routers/api_agents.py
@@ -0,0 +1,43 @@
+import json
+import logging
+
+from fastapi import APIRouter, Request
+from fastapi.requests import Request
+from fastapi.responses import Response
+
+from khoj.database.adapters import AgentAdapters
+from khoj.database.models import KhojUser
+from khoj.routers.helpers import CommonQueryParams
+
+# Initialize Router
+logger = logging.getLogger(__name__)
+
+
+api_agents = APIRouter()
+
+
+@api_agents.get("/", response_class=Response)
+async def all_agents(
+ request: Request,
+ common: CommonQueryParams,
+) -> Response:
+ user: KhojUser = request.user.object if request.user.is_authenticated else None
+ agents = await AgentAdapters.aget_all_accessible_agents(user)
+ agents_packet = list()
+ for agent in agents:
+ agents_packet.append(
+ {
+ "slug": agent.slug,
+ "avatar": agent.avatar,
+ "name": agent.name,
+ "personality": agent.personality,
+ "public": agent.public,
+ "creator": agent.creator.username if agent.creator else None,
+ "managed_by_admin": agent.managed_by_admin,
+ }
+ )
+
+ # Make sure that the agent named 'khoj' is first in the list. Everything else is sorted by name.
+ agents_packet.sort(key=lambda x: x["name"])
+ agents_packet.sort(key=lambda x: x["slug"] == "khoj", reverse=True)
+ return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index e9d87f92..470e8946 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -12,9 +12,17 @@ from starlette.authentication import requires
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
from khoj.database.models import KhojUser
-from khoj.processor.conversation.prompts import help_message, no_entries_found
+from khoj.processor.conversation.prompts import (
+ help_message,
+ no_entries_found,
+ no_notes_found,
+)
from khoj.processor.conversation.utils import save_to_conversation_log
-from khoj.processor.tools.online_search import search_online
+from khoj.processor.tools.online_search import (
+ online_search_enabled,
+ read_webpages,
+ search_online,
+)
from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import (
ApiUserRateLimiter,
@@ -81,9 +89,22 @@ def chat_history(
status_code=404,
)
+ agent_metadata = None
+ if conversation.agent:
+ agent_metadata = {
+ "slug": conversation.agent.slug,
+ "name": conversation.agent.name,
+ "avatar": conversation.agent.avatar,
+ "isCreator": conversation.agent.creator == user,
+ }
+
meta_log = conversation.conversation_log
meta_log.update(
- {"conversation_id": conversation.id, "slug": conversation.title if conversation.title else conversation.slug}
+ {
+ "conversation_id": conversation.id,
+ "slug": conversation.title if conversation.title else conversation.slug,
+ "agent": agent_metadata,
+ }
)
update_telemetry_state(
@@ -148,12 +169,12 @@ def chat_sessions(
async def create_chat_session(
request: Request,
common: CommonQueryParams,
- agent_id: Optional[int] = None,
+ agent_slug: Optional[str] = None,
):
user = request.user.object
# Create new Conversation Session
- conversation = await ConversationAdapters.acreate_conversation_session(user, request.user.client_app, agent_id)
+ conversation = await ConversationAdapters.acreate_conversation_session(user, request.user.client_app, agent_slug)
response = {"conversation_id": conversation.id}
@@ -239,6 +260,7 @@ async def chat(
) -> Response:
user: KhojUser = request.user.object
q = unquote(q)
+ logger.info(f"Chat request by {user.username}: {q}")
await is_ready_to_chat(user)
conversation_commands = [get_conversation_command(query=q, any_references=True)]
@@ -281,7 +303,7 @@ async def chat(
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_commands, location
)
- online_results: Dict = dict()
+ online_results: Dict[str, Dict] = {}
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
no_entries_found_format = no_entries_found.format()
@@ -291,17 +313,35 @@ async def chat(
response_obj = {"response": no_entries_found_format}
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
+ if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
+ no_notes_found_format = no_notes_found.format()
+ if stream:
+ return StreamingResponse(iter([no_notes_found_format]), media_type="text/event-stream", status_code=200)
+ else:
+ response_obj = {"response": no_notes_found_format}
+ return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
+
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands:
+ if not online_search_enabled():
+ conversation_commands.remove(ConversationCommand.Online)
+ # If online search is not enabled, try to read webpages directly
+ if ConversationCommand.Webpage not in conversation_commands:
+ conversation_commands.append(ConversationCommand.Webpage)
+ else:
+ try:
+ online_results = await search_online(defiltered_query, meta_log, location)
+ except ValueError as e:
+ logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
+
+ if ConversationCommand.Webpage in conversation_commands:
try:
- online_results = await search_online(defiltered_query, meta_log, location)
+ online_results = await read_webpages(defiltered_query, meta_log, location)
except ValueError as e:
- return StreamingResponse(
- iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]),
- media_type="text/event-stream",
- status_code=200,
+ logger.warning(
+ f"Error directly reading webpages: {e}. Attempting to respond without online results", exc_info=True
)
if ConversationCommand.Image in conversation_commands:
diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py
index 89fef85b..1d7bbfdd 100644
--- a/src/khoj/routers/auth.py
+++ b/src/khoj/routers/auth.py
@@ -7,6 +7,7 @@ from starlette.authentication import requires
from starlette.config import Config
from starlette.requests import Request
from starlette.responses import HTMLResponse, RedirectResponse, Response
+from starlette.status import HTTP_302_FOUND
from khoj.database.adapters import (
create_khoj_token,
@@ -90,6 +91,7 @@ async def delete_token(request: Request, token: str) -> str:
@auth_router.post("/redirect")
async def auth(request: Request):
form = await request.form()
+ next_url = request.query_params.get("next", "/")
credential = form.get("credential")
csrf_token_cookie = request.cookies.get("g_csrf_token")
@@ -117,9 +119,9 @@ async def auth(request: Request):
metadata={"user_id": str(khoj_user.uuid)},
)
logger.log(logging.INFO, f"New User Created: {khoj_user.uuid}")
- RedirectResponse(url="/?status=welcome")
+ return RedirectResponse(url=f"{next_url}", status_code=HTTP_302_FOUND)
- return RedirectResponse(url="/")
+ return RedirectResponse(url=f"{next_url}")
@auth_router.get("/logout")
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index bdcba09d..bc221859 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -37,6 +37,7 @@ from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import (
ConversationCommand,
is_none_or_empty,
+ is_valid_url,
log_telemetry,
mode_descriptions_for_llm,
timer,
@@ -168,7 +169,8 @@ async def aget_relevant_information_sources(query: str, conversation_history: di
chat_history=chat_history,
)
- response = await send_message_to_model_wrapper(relevant_tools_prompt, response_type="json_object")
+ with timer("Chat actor: Infer information sources to refer", logger):
+ response = await send_message_to_model_wrapper(relevant_tools_prompt, response_type="json_object")
try:
response = response.strip()
@@ -212,7 +214,8 @@ async def aget_relevant_output_modes(query: str, conversation_history: dict):
chat_history=chat_history,
)
- response = await send_message_to_model_wrapper(relevant_mode_prompt)
+ with timer("Chat actor: Infer output mode for chat response", logger):
+ response = await send_message_to_model_wrapper(relevant_mode_prompt)
try:
response = response.strip()
@@ -230,6 +233,36 @@ async def aget_relevant_output_modes(query: str, conversation_history: dict):
return ConversationCommand.Default
+async def infer_webpage_urls(q: str, conversation_history: dict, location_data: LocationData) -> List[str]:
+ """
+ Infer webpage links from the given query
+ """
+ location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
+ chat_history = construct_chat_history(conversation_history)
+
+ utc_date = datetime.utcnow().strftime("%Y-%m-%d")
+ online_queries_prompt = prompts.infer_webpages_to_read.format(
+ current_date=utc_date,
+ query=q,
+ chat_history=chat_history,
+ location=location,
+ )
+
+ with timer("Chat actor: Infer webpage urls to read", logger):
+ response = await send_message_to_model_wrapper(online_queries_prompt, response_type="json_object")
+
+ # Validate that the response is a non-empty, JSON-serializable list of URLs
+ try:
+ response = response.strip()
+ urls = json.loads(response)
+ valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)}
+ if is_none_or_empty(valid_unique_urls):
+ raise ValueError(f"Invalid list of urls: {response}")
+ return list(valid_unique_urls)
+ except Exception:
+ raise ValueError(f"Invalid list of urls: {response}")
+
+
async def generate_online_subqueries(q: str, conversation_history: dict, location_data: LocationData) -> List[str]:
"""
Generate subqueries from the given query
@@ -245,7 +278,8 @@ async def generate_online_subqueries(q: str, conversation_history: dict, locatio
location=location,
)
- response = await send_message_to_model_wrapper(online_queries_prompt, response_type="json_object")
+ with timer("Chat actor: Generate online search subqueries", logger):
+ response = await send_message_to_model_wrapper(online_queries_prompt, response_type="json_object")
# Validate that the response is a non-empty, JSON-serializable list
try:
@@ -274,9 +308,10 @@ async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
corpus=corpus.strip(),
)
- response = await send_message_to_model_wrapper(
- extract_relevant_information, prompts.system_prompt_extract_relevant_information
- )
+ with timer("Chat actor: Extract relevant information from data", logger):
+ response = await send_message_to_model_wrapper(
+ extract_relevant_information, prompts.system_prompt_extract_relevant_information
+ )
return response.strip()
@@ -305,8 +340,8 @@ async def generate_better_image_prompt(
for result in online_results:
if online_results[result].get("answerBox"):
simplified_online_results[result] = online_results[result]["answerBox"]
- elif online_results[result].get("extracted_content"):
- simplified_online_results[result] = online_results[result]["extracted_content"]
+ elif online_results[result].get("webpages"):
+ simplified_online_results[result] = online_results[result]["webpages"]
image_prompt = prompts.image_generation_improve_prompt.format(
query=q,
@@ -317,7 +352,8 @@ async def generate_better_image_prompt(
online_results=simplified_online_results,
)
- response = await send_message_to_model_wrapper(image_prompt)
+ with timer("Chat actor: Generate contextual image prompt", logger):
+ response = await send_message_to_model_wrapper(image_prompt)
return response.strip()
@@ -367,7 +403,7 @@ def generate_chat_response(
meta_log: dict,
conversation: Conversation,
compiled_references: List[str] = [],
- online_results: Dict[str, Any] = {},
+ online_results: Dict[str, Dict] = {},
inferred_queries: List[str] = [],
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
user: KhojUser = None,
diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py
index ddbcc283..cb03cb89 100644
--- a/src/khoj/routers/web_client.py
+++ b/src/khoj/routers/web_client.py
@@ -10,6 +10,7 @@ from starlette.authentication import has_required_scope, requires
from khoj.database import adapters
from khoj.database.adapters import (
+ AgentAdapters,
ConversationAdapters,
EntryAdapters,
get_user_github_config,
@@ -114,8 +115,8 @@ def chat_page(request: Request):
@web_client.get("/login", response_class=FileResponse)
def login_page(request: Request):
+ next_url = request.query_params.get("next", "/")
if request.user.is_authenticated:
- next_url = request.query_params.get("next", "/")
return RedirectResponse(url=next_url)
google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
redirect_uri = str(request.app.url_path_for("auth"))
@@ -124,7 +125,85 @@ def login_page(request: Request):
context={
"request": request,
"google_client_id": google_client_id,
- "redirect_uri": redirect_uri,
+ "redirect_uri": f"{redirect_uri}?next={next_url}",
+ },
+ )
+
+
+@web_client.get("/agents", response_class=HTMLResponse)
+def agents_page(request: Request):
+ user: KhojUser = request.user.object if request.user.is_authenticated else None
+ user_picture = request.session.get("user", {}).get("picture") if user else None
+ agents = AgentAdapters.get_all_accessible_agents(user)
+ agents_packet = list()
+ for agent in agents:
+ agents_packet.append(
+ {
+ "slug": agent.slug,
+ "avatar": agent.avatar,
+ "name": agent.name,
+ "personality": agent.personality,
+ "public": agent.public,
+ "creator": agent.creator.username if agent.creator else None,
+ "managed_by_admin": agent.managed_by_admin,
+ }
+ )
+ return templates.TemplateResponse(
+ "agents.html",
+ context={
+ "request": request,
+ "agents": agents_packet,
+ "khoj_version": state.khoj_version,
+ "username": user.username if user else None,
+ "has_documents": False,
+ "is_active": has_required_scope(request, ["premium"]),
+ "user_photo": user_picture,
+ },
+ )
+
+
+@web_client.get("/agent/{agent_slug}", response_class=HTMLResponse)
+def agent_page(request: Request, agent_slug: str):
+ user: KhojUser = request.user.object if request.user.is_authenticated else None
+ user_picture = request.session.get("user", {}).get("picture") if user else None
+
+ agent = AgentAdapters.get_agent_by_slug(agent_slug)
+
+ if agent == None:
+ return templates.TemplateResponse(
+ "404.html",
+ context={
+ "request": request,
+ "khoj_version": state.khoj_version,
+ "username": user.username if user else None,
+ "has_documents": False,
+ "is_active": has_required_scope(request, ["premium"]),
+ "user_photo": user_picture,
+ },
+ )
+
+ agent_metadata = {
+ "slug": agent.slug,
+ "avatar": agent.avatar,
+ "name": agent.name,
+ "personality": agent.personality,
+ "public": agent.public,
+ "creator": agent.creator.username if agent.creator else None,
+ "managed_by_admin": agent.managed_by_admin,
+ "chat_model": agent.chat_model.chat_model,
+ "creator_not_self": agent.creator != user,
+ }
+
+ return templates.TemplateResponse(
+ "agent.html",
+ context={
+ "request": request,
+ "agent": agent_metadata,
+ "khoj_version": state.khoj_version,
+ "username": user.username if user else None,
+ "has_documents": False,
+ "is_active": has_required_scope(request, ["premium"]),
+ "user_photo": user_picture,
},
)
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index 150398ee..d713c335 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -15,6 +15,7 @@ from os import path
from pathlib import Path
from time import perf_counter
from typing import TYPE_CHECKING, Optional, Union
+from urllib.parse import urlparse
import torch
from asgiref.sync import sync_to_async
@@ -270,6 +271,7 @@ class ConversationCommand(str, Enum):
Notes = "notes"
Help = "help"
Online = "online"
+ Webpage = "webpage"
Image = "image"
@@ -278,15 +280,17 @@ command_descriptions = {
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
ConversationCommand.Online: "Search for information on the internet.",
+ ConversationCommand.Webpage: "Get information from webpage links provided by you.",
ConversationCommand.Image: "Generate images by describing your imagination in words.",
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
}
tool_descriptions_for_llm = {
ConversationCommand.Default: "To use a mix of your internal knowledge and the user's personal knowledge, or if you don't entirely understand the query.",
- ConversationCommand.General: "Use this when you can answer the question without any outside information or personal knowledge",
+ ConversationCommand.General: "To use when you can answer the question without any outside information or personal knowledge",
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
+ ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
}
mode_descriptions_for_llm = {
@@ -340,3 +344,12 @@ def in_debug_mode():
"""Check if Khoj is running in debug mode.
Set KHOJ_DEBUG environment variable to true to enable debug mode."""
return is_env_var_true("KHOJ_DEBUG")
+
+
+def is_valid_url(url: str) -> bool:
+ """Check if a string is a valid URL"""
+ try:
+ result = urlparse(url.strip())
+ return all([result.scheme, result.netloc])
+ except:
+ return False
diff --git a/tests/test_gpt4all_chat_director.py b/tests/test_gpt4all_chat_director.py
index eea6cafa..4baa4f69 100644
--- a/tests/test_gpt4all_chat_director.py
+++ b/tests/test_gpt4all_chat_director.py
@@ -579,7 +579,7 @@ async def test_get_correct_tools_general(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
-async def test_get_correct_tools_with_chat_history(client_offline_chat):
+async def test_get_correct_tools_with_chat_history(client_offline_chat, default_user2):
# Arrange
user_query = "What's the latest in the Israel/Palestine conflict?"
chat_log = [
@@ -590,7 +590,7 @@ async def test_get_correct_tools_with_chat_history(client_offline_chat):
),
("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []),
]
- chat_history = create_conversation(chat_log)
+ chat_history = create_conversation(chat_log, default_user2)
# Act
tools = await aget_relevant_information_sources(user_query, chat_history)
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 086e4895..131c3553 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -7,7 +7,10 @@ import pytest
from scipy.stats import linregress
from khoj.processor.embeddings import EmbeddingsModel
-from khoj.processor.tools.online_search import read_webpage, read_webpage_with_olostep
+from khoj.processor.tools.online_search import (
+ read_webpage_at_url,
+ read_webpage_with_olostep,
+)
from khoj.utils import helpers
@@ -90,7 +93,7 @@ async def test_reading_webpage():
website = "https://en.wikipedia.org/wiki/Great_Chicago_Fire"
# Act
- response = await read_webpage(website)
+ response = await read_webpage_at_url(website)
# Assert
assert (
diff --git a/tests/test_markdown_to_entries.py b/tests/test_markdown_to_entries.py
index 4a4a75f3..12ea238e 100644
--- a/tests/test_markdown_to_entries.py
+++ b/tests/test_markdown_to_entries.py
@@ -34,7 +34,9 @@ def test_markdown_file_with_no_headings_to_jsonl(tmp_path):
# Ensure raw entry with no headings do not get heading prefix prepended
assert not jsonl_data[0]["raw"].startswith("#")
# Ensure compiled entry has filename prepended as top level heading
- assert jsonl_data[0]["compiled"].startswith(expected_heading)
+ assert expected_heading in jsonl_data[0]["compiled"]
+ # Ensure compiled entry also includes the file name
+ assert str(tmp_path) in jsonl_data[0]["compiled"]
def test_single_markdown_entry_to_jsonl(tmp_path):
diff --git a/tests/test_openai_chat_actors.py b/tests/test_openai_chat_actors.py
index b08205de..5c2855b2 100644
--- a/tests/test_openai_chat_actors.py
+++ b/tests/test_openai_chat_actors.py
@@ -11,6 +11,7 @@ from khoj.routers.helpers import (
aget_relevant_information_sources,
aget_relevant_output_modes,
generate_online_subqueries,
+ infer_webpage_urls,
)
from khoj.utils.helpers import ConversationCommand
@@ -546,6 +547,34 @@ async def test_select_data_sources_actor_chooses_to_search_online(chat_client):
assert ConversationCommand.Online in conversation_commands
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.anyio
+@pytest.mark.django_db(transaction=True)
+async def test_select_data_sources_actor_chooses_to_read_webpage(chat_client):
+ # Arrange
+ user_query = "Summarize the wikipedia page on the history of the internet"
+
+ # Act
+ conversation_commands = await aget_relevant_information_sources(user_query, {})
+
+ # Assert
+ assert ConversationCommand.Webpage in conversation_commands
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.anyio
+@pytest.mark.django_db(transaction=True)
+async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client):
+ # Arrange
+ user_query = "Summarize the wikipedia page on the history of the internet"
+
+ # Act
+ urls = await infer_webpage_urls(user_query, {}, None)
+
+ # Assert
+ assert "https://en.wikipedia.org/wiki/History_of_the_Internet" in urls
+
+
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):