Research Mode [Part 2]: Improve Prompts, Edit Chat Messages. Set LLM Seed for Reproducibility (#954)

- Improve chat actors and their prompts for research mode.
- Add documentation to enable the code tool when self-hosting Khoj
- Edit Chat Messages
  - Store Turn Id in each chat message. 
  - Expose API to delete chat message.
  - Expose delete chat message button to turn delete chat message from web app
- Set LLM Generation Seed for Reproducible Debugging and Testing
  - Setting seed for LLM generation is supported by Llama.cpp and OpenAI models. 
    This can (somewhat) restrain LLM output
  - Getting fixed responses for fixed inputs helps test, debug longer reasoning chains like used in advanced reasoning
This commit is contained in:
Debanjum 2024-11-01 18:16:42 -07:00 committed by GitHub
commit cff8e02b60
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 349 additions and 124 deletions

View file

@ -14,6 +14,10 @@ services:
interval: 30s
timeout: 10s
retries: 5
sandbox:
image: ghcr.io/khoj-ai/terrarium:latest
ports:
- "8080:8080"
server:
depends_on:
database:

View file

@ -43,3 +43,6 @@ Slash commands allows you to change what Khoj uses to respond to your query
- **/image**: Generate an image in response to your query.
- **/help**: Use /help to get all available commands and general information about Khoj
- **/summarize**: Can be used to summarize 1 selected file filter for that conversation. Refer to [File Summarization](summarization) for details.
- **/diagram**: Generate a diagram in response to your query. This is built on [Excalidraw](https://excalidraw.com/).
- **/code**: Generate and run very simple Python code snippets. Refer to [Code Generation](code_generation) for details.
- **/research**: Go deeper in a topic for more accurate, in-depth responses.

View file

@ -0,0 +1,30 @@
---
---
# Code Execution
Khoj can generate and run very simple Python code snippets as well. This is useful if you want to generate a plot, run a simple calculation, or do some basic data manipulation. LLMs by default aren't skilled at complex quantitative tasks. Code generation & execution can come in handy for such tasks.
Just use `/code` in your chat command.
### Setup (Self-Hosting)
Run [Cohere's Terrarium](https://github.com/cohere-ai/cohere-terrarium) on your machine to enable code generation and execution.
Check the [instructions](https://github.com/cohere-ai/cohere-terrarium?tab=readme-ov-file#development) for running from source.
For running with Docker, you can use our [docker-compose.yml](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml), or start it manually like this:
```bash
docker pull ghcr.io/khoj-ai/terrarium:latest
docker run -d -p 8080:8080 ghcr.io/khoj-ai/terrarium:latest
```
#### Verify
Verify that it's running, by evaluating a simple Python expression:
```bash
curl -X POST -H "Content-Type: application/json" \
--url http://localhost:8080 \
--data-raw '{"code": "1 + 1"}' \
--no-buffer
```

View file

@ -87,7 +87,7 @@ dependencies = [
"django_apscheduler == 0.6.2",
"anthropic == 0.26.1",
"docx2txt == 0.8",
"google-generativeai == 0.7.2"
"google-generativeai == 0.8.3"
]
dynamic = ["version"]

View file

@ -29,6 +29,7 @@ interface ChatBodyDataProps {
onConversationIdChange?: (conversationId: string) => void;
setQueryToProcess: (query: string) => void;
streamedMessages: StreamMessage[];
setStreamedMessages: (messages: StreamMessage[]) => void;
setUploadedFiles: (files: string[]) => void;
isMobileWidth?: boolean;
isLoggedIn: boolean;
@ -118,6 +119,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
setAgent={setAgentMetadata}
pendingMessage={processingMessage ? message : ""}
incomingMessages={props.streamedMessages}
setIncomingMessages={props.setStreamedMessages}
customClassName={chatHistoryCustomClassName}
/>
</div>
@ -351,6 +353,7 @@ export default function Chat() {
<ChatBodyData
isLoggedIn={authenticatedData !== null}
streamedMessages={messages}
setStreamedMessages={setMessages}
chatOptionsData={chatOptionsData}
setTitle={setTitle}
setQueryToProcess={setQueryToProcess}

View file

@ -11,6 +11,11 @@ export interface RawReferenceData {
codeContext?: CodeContext;
}
export interface MessageMetadata {
conversationId: string;
turnId: string;
}
export interface ResponseWithIntent {
intentType: string;
response: string;
@ -90,6 +95,9 @@ export function processMessageChunk(
if (references.onlineContext) onlineContext = references.onlineContext;
if (references.codeContext) codeContext = references.codeContext;
return { context, onlineContext, codeContext };
} else if (chunk.type === "metadata") {
const messageMetadata = chunk.data as MessageMetadata;
currentMessage.turnId = messageMetadata.turnId;
} else if (chunk.type === "message") {
const chunkData = chunk.data;
// Here, handle if the response is a JSON response with an image, but the intentType is excalidraw

View file

@ -42,6 +42,13 @@ export function converColorToBgGradient(color: string) {
return `${convertToBGGradientClass(color)} dark:border dark:border-neutral-700`;
}
export function convertColorToCaretClass(color: string | undefined) {
if (color && tailwindColors.includes(color)) {
return `caret-${color}-500`;
}
return `caret-orange-500`;
}
export function convertColorToRingClass(color: string | undefined) {
if (color && tailwindColors.includes(color)) {
return `focus-visible:ring-${color}-500`;

View file

@ -34,8 +34,9 @@ interface ChatHistory {
interface ChatHistoryProps {
conversationId: string;
setTitle: (title: string) => void;
incomingMessages?: StreamMessage[];
pendingMessage?: string;
incomingMessages?: StreamMessage[];
setIncomingMessages?: (incomingMessages: StreamMessage[]) => void;
publicConversationSlug?: string;
setAgent: (agent: AgentData) => void;
customClassName?: string;
@ -45,7 +46,7 @@ interface TrainOfThoughtComponentProps {
trainOfThought: string[];
lastMessage: boolean;
agentColor: string;
key: string;
keyId: string;
completed?: boolean;
}
@ -56,7 +57,7 @@ function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
return (
<div
className={`${!collapsed ? styles.trainOfThought + " shadow-sm" : ""}`}
key={props.key}
key={props.keyId}
>
{!props.completed && <InlineLoading className="float-right" />}
{props.completed &&
@ -97,6 +98,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
const [data, setData] = useState<ChatHistoryData | null>(null);
const [currentPage, setCurrentPage] = useState(0);
const [hasMoreMessages, setHasMoreMessages] = useState(true);
const [currentTurnId, setCurrentTurnId] = useState<string | null>(null);
const sentinelRef = useRef<HTMLDivElement | null>(null);
const scrollAreaRef = useRef<HTMLDivElement | null>(null);
const latestUserMessageRef = useRef<HTMLDivElement | null>(null);
@ -177,6 +179,10 @@ export default function ChatHistory(props: ChatHistoryProps) {
if (lastMessage && !lastMessage.completed) {
setIncompleteIncomingMessageIndex(props.incomingMessages.length - 1);
props.setTitle(lastMessage.rawQuery);
// Store the turnId when we get it
if (lastMessage.turnId) {
setCurrentTurnId(lastMessage.turnId);
}
}
}
}, [props.incomingMessages]);
@ -278,6 +284,25 @@ export default function ChatHistory(props: ChatHistoryProps) {
return data.agent?.persona;
}
const handleDeleteMessage = (turnId?: string) => {
if (!turnId) return;
setData((prevData) => {
if (!prevData || !turnId) return prevData;
return {
...prevData,
chat: prevData.chat.filter((msg) => msg.turnId !== turnId),
};
});
// Update incoming messages if they exist
if (props.incomingMessages && props.setIncomingMessages) {
props.setIncomingMessages(
props.incomingMessages.filter((msg) => msg.turnId !== turnId),
);
}
};
if (!props.conversationId && !props.publicConversationSlug) {
return null;
}
@ -293,6 +318,18 @@ export default function ChatHistory(props: ChatHistoryProps) {
data.chat &&
data.chat.map((chatMessage, index) => (
<>
{chatMessage.trainOfThought && chatMessage.by === "khoj" && (
<TrainOfThoughtComponent
trainOfThought={chatMessage.trainOfThought?.map(
(train) => train.data,
)}
lastMessage={false}
agentColor={data?.agent?.color || "orange"}
key={`${index}trainOfThought`}
keyId={`${index}trainOfThought`}
completed={true}
/>
)}
<ChatMessage
key={`${index}fullHistory`}
ref={
@ -312,22 +349,14 @@ export default function ChatHistory(props: ChatHistoryProps) {
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}
isLastMessage={index === data.chat.length - 1}
onDeleteMessage={handleDeleteMessage}
conversationId={props.conversationId}
/>
{chatMessage.trainOfThought && chatMessage.by === "khoj" && (
<TrainOfThoughtComponent
trainOfThought={chatMessage.trainOfThought?.map(
(train) => train.data,
)}
lastMessage={false}
agentColor={data?.agent?.color || "orange"}
key={`${index}trainOfThought`}
completed={true}
/>
)}
</>
))}
{props.incomingMessages &&
props.incomingMessages.map((message, index) => {
const messageTurnId = message.turnId ?? currentTurnId ?? undefined;
return (
<React.Fragment key={`incomingMessage${index}`}>
<ChatMessage
@ -342,9 +371,14 @@ export default function ChatHistory(props: ChatHistoryProps) {
by: "you",
automationId: "",
images: message.images,
conversationId: props.conversationId,
turnId: messageTurnId,
}}
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}
onDeleteMessage={handleDeleteMessage}
conversationId={props.conversationId}
turnId={messageTurnId}
/>
{message.trainOfThought && (
<TrainOfThoughtComponent
@ -352,6 +386,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
lastMessage={index === incompleteIncomingMessageIndex}
agentColor={data?.agent?.color || "orange"}
key={`${index}trainOfThought`}
keyId={`${index}trainOfThought`}
completed={message.completed}
/>
)}
@ -373,7 +408,12 @@ export default function ChatHistory(props: ChatHistoryProps) {
"memory-type": "",
"inferred-queries": message.inferredQueries || [],
},
conversationId: props.conversationId,
turnId: messageTurnId,
}}
conversationId={props.conversationId}
turnId={messageTurnId}
onDeleteMessage={handleDeleteMessage}
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}
isLastMessage={true}
@ -393,7 +433,11 @@ export default function ChatHistory(props: ChatHistoryProps) {
created: new Date().getTime().toString(),
by: "you",
automationId: "",
conversationId: props.conversationId,
turnId: undefined,
}}
conversationId={props.conversationId}
onDeleteMessage={handleDeleteMessage}
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}
isLastMessage={true}

View file

@ -149,7 +149,7 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
}
let messageToSend = message.trim();
if (useResearchMode) {
if (useResearchMode && !messageToSend.startsWith("/research")) {
messageToSend = `/research ${messageToSend}`;
}
@ -398,7 +398,7 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
<PopoverContent
onOpenAutoFocus={(e) => e.preventDefault()}
className={`${props.isMobileWidth ? "w-[100vw]" : "w-full"} rounded-md`}
side="top"
side="bottom"
align="center"
/* Offset below text area on home page (i.e where conversationId is unset) */
sideOffset={props.conversationId ? 0 : 80}
@ -590,8 +590,8 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
</Button>
</TooltipTrigger>
<TooltipContent className="text-xs">
Research Mode allows you to get more deeply researched, detailed
responses. Response times may be longer.
(Experimental) Research Mode allows you to get more deeply researched,
detailed responses. Response times may be longer.
</TooltipContent>
</Tooltip>
</TooltipProvider>

View file

@ -29,6 +29,7 @@ import {
Check,
Code,
Shapes,
Trash,
} from "@phosphor-icons/react";
import DOMPurify from "dompurify";
@ -146,6 +147,8 @@ export interface SingleChatMessage {
intent?: Intent;
agent?: AgentData;
images?: string[];
conversationId: string;
turnId?: string;
}
export interface StreamMessage {
@ -161,6 +164,7 @@ export interface StreamMessage {
images?: string[];
intentType?: string;
inferredQueries?: string[];
turnId?: string;
}
export interface ChatHistoryData {
@ -242,6 +246,9 @@ interface ChatMessageProps {
borderLeftColor?: string;
isLastMessage?: boolean;
agent?: AgentData;
onDeleteMessage: (turnId?: string) => void;
conversationId: string;
turnId?: string;
}
interface TrainOfThoughtProps {
@ -654,6 +661,27 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
});
}
const deleteMessage = async (message: SingleChatMessage) => {
const turnId = message.turnId || props.turnId;
const response = await fetch("/api/chat/conversation/message", {
method: "DELETE",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
conversation_id: props.conversationId,
turn_id: turnId,
}),
});
if (response.ok) {
// Update the UI after successful deletion
props.onDeleteMessage(turnId);
} else {
console.error("Failed to delete message");
}
};
const allReferences = constructAllReferences(
props.chatMessage.context,
props.chatMessage.onlineContext,
@ -716,6 +744,18 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
/>
</button>
))}
{props.chatMessage.turnId && (
<button
title="Delete"
className={`${styles.deleteButton}`}
onClick={() => deleteMessage(props.chatMessage)}
>
<Trash
alt="Delete Message"
className="hsl(var(--muted-foreground)) hover:text-red-500"
/>
</button>
)}
<button
title="Copy"
className={`${styles.copyButton}`}

View file

@ -195,8 +195,12 @@ function ReferenceVerification(props: ReferenceVerificationProps) {
created: new Date().toISOString(),
onlineContext: {},
codeContext: {},
conversationId: props.conversationId,
turnId: "",
}}
isMobileWidth={isMobileWidth}
onDeleteMessage={(turnId?: string) => {}}
conversationId={props.conversationId}
/>
</div>
);
@ -626,7 +630,11 @@ export default function FactChecker() {
created: new Date().toISOString(),
onlineContext: {},
codeContext: {},
conversationId: conversationID,
turnId: "",
}}
conversationId={conversationID}
onDeleteMessage={(turnId?: string) => {}}
isMobileWidth={isMobileWidth}
/>
</div>

View file

@ -32,6 +32,11 @@ const config = {
/ring-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
variants: ["focus-visible", "dark"],
},
{
pattern:
/caret-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
variants: ["focus", "dark"],
},
],
darkMode: ["class"],
content: [

View file

@ -262,7 +262,7 @@ def configure_server(
initialize_content(regenerate, search_type, user)
except Exception as e:
raise e
logger.error(f"Failed to load some search models: {e}", exc_info=True)
def setup_default_agent(user: KhojUser):

View file

@ -476,9 +476,8 @@ def get_default_search_model() -> SearchModelConfig:
if default_search_model:
return default_search_model
else:
elif SearchModelConfig.objects.count() == 0:
SearchModelConfig.objects.create()
return SearchModelConfig.objects.first()
@ -1319,6 +1318,8 @@ class ConversationAdapters:
def add_files_to_filter(user: KhojUser, conversation_id: str, files: List[str]):
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
file_list = EntryAdapters.get_all_filenames_by_source(user, "computer")
if not conversation:
return []
for filename in files:
if filename in file_list and filename not in conversation.file_filters:
conversation.file_filters.append(filename)
@ -1332,6 +1333,8 @@ class ConversationAdapters:
@staticmethod
def remove_files_from_filter(user: KhojUser, conversation_id: str, files: List[str]):
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
if not conversation:
return []
for filename in files:
if filename in conversation.file_filters:
conversation.file_filters.remove(filename)
@ -1343,6 +1346,17 @@ class ConversationAdapters:
conversation.save()
return conversation.file_filters
@staticmethod
def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str):
conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id)
if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"):
return False
conversation_log = conversation.conversation_log
updated_log = [msg for msg in conversation_log["chat"] if msg.get("turnId") != turn_id]
conversation.conversation_log["chat"] = updated_log
conversation.save()
return True
class FileObjectAdapters:
@staticmethod

View file

@ -1,5 +1,6 @@
import json
import logging
import os
from datetime import datetime, timedelta
from threading import Thread
from typing import Any, Iterator, List, Optional, Union
@ -263,8 +264,14 @@ def send_message_to_model_offline(
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
response = offline_chat_model.create_chat_completion(
messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
messages_dict,
stop=stop,
stream=streaming,
temperature=temperature,
response_format={"type": response_type},
seed=seed,
)
if streaming:

View file

@ -1,4 +1,5 @@
import logging
import os
from threading import Thread
from typing import Dict
@ -60,6 +61,9 @@ def completion_with_backoff(
model_kwargs.pop("stop", None)
model_kwargs.pop("response_format", None)
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
chat = client.chat.completions.create(
stream=stream,
messages=formatted_messages, # type: ignore
@ -157,6 +161,9 @@ def llm_thread(
model_kwargs.pop("stop", None)
model_kwargs.pop("response_format", None)
if os.getenv("KHOJ_LLM_SEED"):
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
chat = client.chat.completions.create(
stream=stream,
messages=formatted_messages,

View file

@ -625,25 +625,25 @@ Create a multi-step plan and intelligently iterate on the plan based on the retr
{personality_context}
# Instructions
- Ask detailed queries to the tool AIs provided below, one at a time, to discover required information or run calculations. Their response will be shown to you in the next iteration.
- Break down your research process into independent, self-contained steps that can be executed sequentially to answer the user's query. Write your step-by-step plan in the scratchpad.
- Ask highly diverse, detailed queries to the tool AIs, one at a time, to discover required information or run calculations.
- NEVER repeat the same query across iterations.
- Ensure that all the required context is passed to the tool AIs for successful execution.
- Ensure that you go deeper when possible and try more broad, creative strategies when a path is not yielding useful results. Build on the results of the previous iterations.
- Ask highly diverse, detailed queries to the tool AIs, one tool AI at a time, to discover required information or run calculations. Their response will be shown to you in the next iteration.
- Break down your research process into independent, self-contained steps that can be executed sequentially using the available tool AIs to answer the user's query. Write your step-by-step plan in the scratchpad.
- Always ask a new query that was not asked to the tool AI in a previous iteration. Build on the results of the previous iterations.
- Ensure that all required context is passed to the tool AIs for successful execution. They only know the context provided in your query.
- Think step by step to come up with creative strategies when the previous iteration did not yield useful results.
- You are allowed upto {max_iterations} iterations to use the help of the provided tool AIs to answer the user's question.
- Stop when you have the required information by returning a JSON object with an empty "tool" field. E.g., {{scratchpad: "I have all I need", tool: "", query: ""}}
# Examples
Assuming you can search the user's notes and the internet.
- When they ask for the population of their hometown
- When the user asks for the population of their hometown
1. Try look up their hometown in their notes. Ask the note search AI to search for their birth certificate, childhood memories, school, resume etc.
2. If not found in their notes, try infer their hometown from their online social media profiles. Ask the online search AI to look for {username}'s biography, school, resume on linkedin, facebook, website etc.
3. Only then try find the latest population of their hometown by reading official websites with the help of the online search and web page reading AI.
- When user for their computer's specs
- When the user asks for their computer's specs
1. Try find their computer model in their notes.
2. Now find webpages with their computer model's spec online and read them.
- When I ask what clothes to carry for their upcoming trip
2. Now find webpages with their computer model's spec online.
3. Ask the the webpage tool AI to extract the required information from the relevant webpages.
- When the user asks what clothes to carry for their upcoming trip
1. Find the itinerary of their upcoming trip in their notes.
2. Next find the weather forecast at the destination online.
3. Then find if they mentioned what clothes they own in their notes.
@ -666,7 +666,7 @@ Which of the tool AIs listed below would you use to answer the user's question?
Return the next tool AI to use and the query to ask it. Your response should always be a valid JSON object. Do not say anything else.
Response format:
{{"scratchpad": "<your_scratchpad_to_reason_about_which_tool_to_use>", "tool": "<name_of_tool_ai>", "query": "<your_detailed_query_for_the_tool_ai>"}}
{{"scratchpad": "<your_scratchpad_to_reason_about_which_tool_to_use>", "query": "<your_detailed_query_for_the_tool_ai>", "tool": "<name_of_tool_ai>"}}
""".strip()
)
@ -798,8 +798,8 @@ Khoj:
online_search_conversation_subqueries = PromptTemplate.from_template(
"""
You are Khoj, an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
- You will receive the conversation history as context.
- Add as much context from the previous questions and answers as required into your search queries.
- You will receive the actual chat history as context.
- Add as much context from the chat history as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information.
- Use site: google search operator when appropriate
- You have access to the the whole internet to retrieve information.
@ -812,58 +812,56 @@ User's Location: {location}
{username}
Here are some examples:
History:
Example Chat History:
User: I like to use Hacker News to get my tech news.
Khoj: {{queries: ["what is Hacker News?", "Hacker News website for tech news"]}}
AI: Hacker News is an online forum for sharing and discussing the latest tech news. It is a great place to learn about new technologies and startups.
Q: Summarize the top posts on HackerNews
User: Summarize the top posts on HackerNews
Khoj: {{"queries": ["top posts on HackerNews"]}}
History:
Q: Tell me the latest news about the farmers protest in Colombia and China on Reuters
Example Chat History:
User: Tell me the latest news about the farmers protest in Colombia and China on Reuters
Khoj: {{"queries": ["site:reuters.com farmers protest Colombia", "site:reuters.com farmers protest China"]}}
History:
Example Chat History:
User: I'm currently living in New York but I'm thinking about moving to San Francisco.
Khoj: {{"queries": ["New York city vs San Francisco life", "San Francisco living cost", "New York city living cost"]}}
AI: New York is a great city to live in. It has a lot of great restaurants and museums. San Francisco is also a great city to live in. It has good access to nature and a great tech scene.
Q: What is the climate like in those cities?
Khoj: {{"queries": ["climate in new york city", "climate in san francisco"]}}
User: What is the climate like in those cities?
Khoj: {{"queries": ["climate in New York city", "climate in San Francisco"]}}
History:
AI: Hey, how is it going?
User: Going well. Ananya is in town tonight!
Example Chat History:
User: Hey, Ananya is in town tonight!
Khoj: {{"queries": ["events in {location} tonight", "best restaurants in {location}", "places to visit in {location}"]}}
AI: Oh that's awesome! What are your plans for the evening?
Q: She wants to see a movie. Any decent sci-fi movies playing at the local theater?
User: She wants to see a movie. Any decent sci-fi movies playing at the local theater?
Khoj: {{"queries": ["new sci-fi movies in theaters near {location}"]}}
History:
Example Chat History:
User: Can I chat with you over WhatsApp?
Khoj: {{"queries": ["site:khoj.dev chat with Khoj on Whatsapp"]}}
AI: Yes, you can chat with me using WhatsApp.
Q: How
Khoj: {{"queries": ["site:khoj.dev chat with Khoj on Whatsapp"]}}
History:
Q: How do I share my files with you?
Example Chat History:
User: How do I share my files with Khoj?
Khoj: {{"queries": ["site:khoj.dev sync files with Khoj"]}}
History:
Example Chat History:
User: I need to transport a lot of oranges to the moon. Are there any rockets that can fit a lot of oranges?
Khoj: {{"queries": ["current rockets with large cargo capacity", "rocket rideshare cost by cargo capacity"]}}
AI: NASA's Saturn V rocket frequently makes lunar trips and has a large cargo capacity.
Q: How many oranges would fit in NASA's Saturn V rocket?
Khoj: {{"queries": ["volume of an orange", "volume of saturn v rocket"]}}
User: How many oranges would fit in NASA's Saturn V rocket?
Khoj: {{"queries": ["volume of an orange", "volume of Saturn V rocket"]}}
Now it's your turn to construct Google search queries to answer the user's question. Provide them as a list of strings in a JSON object. Do not say anything else.
History:
Actual Chat History:
{chat_history}
Q: {query}
User: {query}
Khoj:
""".strip()
)

View file

@ -1,9 +1,11 @@
import base64
import json
import logging
import math
import mimetypes
import os
import queue
import uuid
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
@ -134,7 +136,11 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
for chat in conversation_history.get("chat", [])[-n:]:
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: {chat['message']}\n"
if chat["intent"].get("inferred-queries"):
chat_history += f'Khoj: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
chat_history += f"{agent_name}: {chat['message']}\n\n"
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: [generated image redacted for space]\n"
@ -185,6 +191,7 @@ class ChatEvent(Enum):
MESSAGE = "message"
REFERENCES = "references"
STATUS = "status"
METADATA = "metadata"
def message_to_log(
@ -232,12 +239,14 @@ def save_to_conversation_log(
train_of_thought: List[Any] = [],
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
turn_id = tracer.get("mid") or str(uuid.uuid4())
updated_conversation = message_to_log(
user_message=q,
chat_response=chat_response,
user_message_metadata={
"created": user_message_time,
"images": query_images,
"turnId": turn_id,
},
khoj_message_metadata={
"context": compiled_references,
@ -246,6 +255,7 @@ def save_to_conversation_log(
"codeContext": code_results,
"automationId": automation_id,
"trainOfThought": train_of_thought,
"turnId": turn_id,
},
conversation_log=meta_log.get("chat", []),
train_of_thought=train_of_thought,
@ -501,15 +511,12 @@ def commit_conversation_trace(
Returns the path to the repository.
"""
# Serialize session, system message and response to yaml
system_message_yaml = yaml.dump(system_message, allow_unicode=True, sort_keys=False, default_flow_style=False)
response_yaml = yaml.dump(response, allow_unicode=True, sort_keys=False, default_flow_style=False)
system_message_yaml = json.dumps(system_message, ensure_ascii=False, sort_keys=False)
response_yaml = json.dumps(response, ensure_ascii=False, sort_keys=False)
formatted_session = [{"role": message.role, "content": message.content} for message in session]
session_yaml = yaml.dump(formatted_session, allow_unicode=True, sort_keys=False, default_flow_style=False)
session_yaml = json.dumps(formatted_session, ensure_ascii=False, sort_keys=False)
query = (
yaml.dump(session[-1].content, allow_unicode=True, sort_keys=False, default_flow_style=False)
.strip()
.removeprefix("'")
.removesuffix("'")
json.dumps(session[-1].content, ensure_ascii=False, sort_keys=False).strip().removeprefix("'").removesuffix("'")
) # Extract serialized query from chat session
# Extract chat metadata for session

View file

@ -13,7 +13,7 @@ from tenacity import (
)
from torch import nn
from khoj.utils.helpers import get_device, merge_dicts, timer
from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer
from khoj.utils.rawconfig import SearchResponse
logger = logging.getLogger(__name__)
@ -31,9 +31,9 @@ class EmbeddingsModel:
):
default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
self.query_encode_kwargs = merge_dicts(fix_json_dict(query_encode_kwargs), default_query_encode_kwargs)
self.docs_encode_kwargs = merge_dicts(fix_json_dict(docs_encode_kwargs), default_docs_encode_kwargs)
self.model_kwargs = merge_dicts(fix_json_dict(model_kwargs), {"device": get_device()})
self.model_name = model_name
self.inference_endpoint = embeddings_inference_endpoint
self.api_key = embeddings_inference_endpoint_api_key

View file

@ -54,6 +54,7 @@ OLOSTEP_QUERY_PARAMS = {
}
DEFAULT_MAX_WEBPAGES_TO_READ = 1
MAX_WEBPAGES_TO_INFER = 10
async def search_online(
@ -157,13 +158,16 @@ async def read_webpages(
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
):
"Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read")
if send_status_func:
async for event in send_status_func(f"**Inferring web pages to read**"):
yield {ChatEvent.STATUS: event}
urls = await infer_webpage_urls(query, conversation_history, location, user, query_images)
urls = await infer_webpage_urls(
query, conversation_history, location, user, query_images, agent=agent, tracer=tracer
)
# Get the top 10 web pages to read
urls = urls[:max_webpages_to_read]
logger.info(f"Reading web pages at: {urls}")
if send_status_func:

View file

@ -31,6 +31,7 @@ from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.run_code import run_code
from khoj.routers.api import extract_references_and_questions
from khoj.routers.email import send_query_feedback
from khoj.routers.helpers import (
ApiImageRateLimiter,
ApiUserRateLimiter,
@ -38,13 +39,14 @@ from khoj.routers.helpers import (
ChatRequestBody,
CommonQueryParams,
ConversationCommandRateLimiter,
DeleteMessageRequestBody,
FeedbackData,
agenerate_chat_response,
aget_relevant_information_sources,
aget_relevant_output_modes,
construct_automation_created_message,
create_automation,
extract_relevant_info,
extract_relevant_summary,
generate_excalidraw_diagram,
generate_summary_from_files,
get_conversation_command,
@ -75,16 +77,12 @@ from khoj.utils.rawconfig import FileFilterRequest, FilesFilterRequest, Location
# Initialize Router
logger = logging.getLogger(__name__)
conversation_command_rate_limiter = ConversationCommandRateLimiter(
trial_rate_limit=100, subscribed_rate_limit=6000, slug="command"
trial_rate_limit=20, subscribed_rate_limit=75, slug="command"
)
api_chat = APIRouter()
from pydantic import BaseModel
from khoj.routers.email import send_query_feedback
@api_chat.get("/conversation/file-filters/{conversation_id}", response_class=Response)
@requires(["authenticated"])
@ -146,12 +144,6 @@ def remove_file_filter(request: Request, filter: FileFilterRequest) -> Response:
return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200)
class FeedbackData(BaseModel):
uquery: str
kquery: str
sentiment: str
@api_chat.post("/feedback")
@requires(["authenticated"])
async def sendfeedback(request: Request, data: FeedbackData):
@ -166,10 +158,10 @@ async def text_to_speech(
common: CommonQueryParams,
text: str,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
ApiUserRateLimiter(requests=30, subscribed_requests=30, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=50, subscribed_requests=300, window=60 * 60 * 24, slug="chat_day")
ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
) -> Response:
voice_model = await ConversationAdapters.aget_voice_model_config(request.user.object)
@ -534,6 +526,19 @@ async def set_conversation_title(
)
@api_chat.delete("/conversation/message", response_class=Response)
@requires(["authenticated"])
def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -> Response:
user = request.user.object
success = ConversationAdapters.delete_message_by_turn_id(
user, delete_request.conversation_id, delete_request.turn_id
)
if success:
return Response(content=json.dumps({"status": "ok"}), media_type="application/json", status_code=200)
else:
return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404)
@api_chat.post("")
@requires(["authenticated"])
async def chat(
@ -541,10 +546,10 @@ async def chat(
common: CommonQueryParams,
body: ChatRequestBody,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=60, subscribed_requests=200, window=60, slug="chat_minute")
ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
):
@ -555,6 +560,7 @@ async def chat(
stream = body.stream
title = body.title
conversation_id = body.conversation_id
turn_id = str(body.turn_id or uuid.uuid4())
city = body.city
region = body.region
country = body.country or get_country_name_from_timezone(body.timezone)
@ -574,7 +580,7 @@ async def chat(
nonlocal conversation_id
tracer: dict = {
"mid": f"{uuid.uuid4()}",
"mid": turn_id,
"cid": conversation_id,
"uid": user.id,
"khoj_version": state.khoj_version,
@ -607,7 +613,7 @@ async def chat(
if event_type == ChatEvent.MESSAGE:
yield data
elif event_type == ChatEvent.REFERENCES or stream:
elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream:
yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False)
except asyncio.CancelledError as e:
connection_alive = False
@ -651,6 +657,11 @@ async def chat(
metadata=chat_metadata,
)
if is_query_empty(q):
async for result in send_llm_response("Please ask your query to get started."):
yield result
return
conversation_commands = [get_conversation_command(query=q, any_references=True)]
conversation = await ConversationAdapters.aget_conversation_by_user(
@ -666,6 +677,9 @@ async def chat(
return
conversation_id = conversation.id
async for event in send_event(ChatEvent.METADATA, {"conversationId": str(conversation_id), "turnId": turn_id}):
yield event
agent: Agent | None = None
default_agent = await AgentAdapters.aget_default_agent()
if conversation.agent and conversation.agent != default_agent:
@ -677,17 +691,11 @@ async def chat(
agent = default_agent
await is_ready_to_chat(user)
user_name = await aget_user_name(user)
location = None
if city or region or country or country_code:
location = LocationData(city=city, region=region, country=country, country_code=country_code)
if is_query_empty(q):
async for result in send_llm_response("Please ask your query to get started."):
yield result
return
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
meta_log = conversation.conversation_log
@ -699,7 +707,6 @@ async def chat(
## Extract Document References
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
defiltered_query = defilter_query(q)
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(
@ -730,6 +737,12 @@ async def chat(
if mode not in conversation_commands:
conversation_commands.append(mode)
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
defiltered_query = defilter_query(q)
if conversation_commands == [ConversationCommand.Research]:
async for research_result in execute_information_collection(
request=request,

View file

@ -478,6 +478,9 @@ async def infer_webpage_urls(
valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)}
if is_none_or_empty(valid_unique_urls):
raise ValueError(f"Invalid list of urls: {response}")
if len(valid_unique_urls) == 0:
logger.error(f"No valid URLs found in response: {response}")
return []
return list(valid_unique_urls)
except Exception:
raise ValueError(f"Invalid list of urls: {response}")
@ -1255,6 +1258,7 @@ class ChatRequestBody(BaseModel):
stream: Optional[bool] = False
title: Optional[str] = None
conversation_id: Optional[str] = None
turn_id: Optional[str] = None
city: Optional[str] = None
region: Optional[str] = None
country: Optional[str] = None
@ -1264,6 +1268,17 @@ class ChatRequestBody(BaseModel):
create_new: Optional[bool] = False
class DeleteMessageRequestBody(BaseModel):
conversation_id: str
turn_id: str
class FeedbackData(BaseModel):
uquery: str
kquery: str
sentiment: str
class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
self.requests = requests
@ -1366,7 +1381,7 @@ class ConversationCommandRateLimiter:
self.slug = slug
self.trial_rate_limit = trial_rate_limit
self.subscribed_rate_limit = subscribed_rate_limit
self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image]
self.restricted_commands = [ConversationCommand.Research]
async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
if state.billing_enabled is False:

View file

@ -1,12 +1,11 @@
import json
import logging
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
from typing import Callable, Dict, List, Optional
import yaml
from fastapi import Request
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
@ -191,18 +190,18 @@ async def execute_information_collection(
document_results = result[0]
this_iteration.context += document_results
if not is_none_or_empty(document_results):
try:
distinct_files = {d["file"] for d in document_results}
distinct_headings = set([d["compiled"].split("\n")[0] for d in document_results if "compiled" in d])
# Strip only leading # from headings
headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
async for result in send_status_func(
f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}"
):
yield result
except Exception as e:
logger.error(f"Error extracting document references: {e}", exc_info=True)
if not is_none_or_empty(document_results):
try:
distinct_files = {d["file"] for d in document_results}
distinct_headings = set([d["compiled"].split("\n")[0] for d in document_results if "compiled" in d])
# Strip only leading # from headings
headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
async for result in send_status_func(
f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}"
):
yield result
except Exception as e:
logger.error(f"Error extracting document references: {e}", exc_info=True)
elif this_iteration.tool == ConversationCommand.Online:
async for result in search_online(
@ -306,13 +305,13 @@ async def execute_information_collection(
if document_results or online_results or code_results or summarize_files:
results_data = f"**Results**:\n"
if document_results:
results_data += f"**Document References**: {yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"**Document References**:\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
if online_results:
results_data += f"**Online Results**: {yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"**Online Results**:\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
if code_results:
results_data += f"**Code Results**: {yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"**Code Results**:\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
if summarize_files:
results_data += f"**Summarized Files**: {yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"**Summarized Files**:\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
# intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
this_iteration.summarizedResult = results_data

View file

@ -101,6 +101,15 @@ def merge_dicts(priority_dict: dict, default_dict: dict):
return merged_dict
def fix_json_dict(json_dict: dict) -> dict:
for k, v in json_dict.items():
if v == "True" or v == "False":
json_dict[k] = v == "True"
if isinstance(v, dict):
json_dict[k] = fix_json_dict(v)
return json_dict
def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]:
"Get file type from file mime type"
@ -359,9 +368,9 @@ tool_descriptions_for_llm = {
function_calling_description_for_llm = {
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
ConversationCommand.Online: "To search the internet for information. Provide all relevant context to ensure new searches, not previously run, are performed.",
ConversationCommand.Webpage: "To extract information from a webpage. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share the webpage link and information to extract in your query.",
ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create documents and charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.",
ConversationCommand.Online: "To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed.",
ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share the webpage links and information to extract in your query.",
ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.",
}
mode_descriptions_for_llm = {