Merge branch 'master' of github.com:khoj-ai/khoj into features/chat-ui-updates-big

sabaimran 2024-07-08 17:00:42 +05:30
commit bf4c2f219e
60 changed files with 1419 additions and 492 deletions


@ -15,6 +15,7 @@ Khoj will keep these files in sync to provide contextual responses when you sear
- **Faster answers**: Find answers quickly, from your private notes or the public internet
- **Assisted creativity**: Smoothly weave across retrieving answers and generating content
- **Iterative discovery**: Iteratively explore and re-discover your notes
- **Quick access**: Use [Khoj Mini](/features/khoj_mini) on the desktop to quickly pull up a mini chat module for quicker answers
- **Search**
- **Natural**: Advanced natural language understanding using Transformer based ML Models
- **Incremental**: Incremental search for a fast, search-as-you-type experience


@ -0,0 +1,9 @@
# Desktop Quick Chat (Khoj Mini)
Once you have the Khoj [desktop application](https://khoj.dev/downloads) installed, you can use the desktop shortcut to quickly pull up a mini chat module for quicker answers. See the desktop setup instructions [in the docs](/clients/desktop.md) for more information.
To use it, copy the text you want to include in your query, then press `Ctrl + Shift + K` (or `Cmd + Shift + K` on Mac) to open the mini chat module. The copied text is pasted into the chat module automatically. Edit it if you want to refine your query, then hit Enter to get the answer.
The desktop shortcut is a great way to quickly get answers to your questions without having to switch between windows or tabs. It's especially useful when you're working on a project and need to quickly look up something without losing your focus.
![Desktop Shortcut](https://assets.khoj.dev/courseload_decision_dekstop.gif)


@ -1,17 +1,21 @@
# Online Search
By default, Khoj will try to infer which information-sourcing tools are required to answer your question. Sometimes a question needs information that the LLM's own knowledge doesn't cover. In that case, it will use the `online` search feature.
Khoj researches on the internet to ground its responses when it determines that it needs fresh information outside its existing knowledge to answer the query. It will always show any online references it used to respond to your requests.
For example, these queries would trigger an online search:
By default, Khoj will try to infer which information sources it needs to read to answer your question. This can include reading your documents or researching information online. You can also explicitly trigger an online search by adding the `/online` prefix to your chat query.
Example queries that should trigger an online search:
- What's the latest news about the Israel-Palestine war?
- Where can I find the best pizza in New York City?
- Deadline for filing taxes 2024.
- /online Deadline for filing taxes 2024.
- Give me a summary of this article: https://en.wikipedia.org/wiki/Haitian_Revolution
Try it out yourself! https://app.khoj.dev
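The `/online` prefix described above can be pictured as a slash command split off the front of the chat query. The following is a minimal sketch only, not Khoj's implementation: the helper name `parseChatQuery` and the `ParsedQuery` shape are hypothetical and exist purely for illustration.

```typescript
// Hypothetical shape for a parsed chat query (an assumption, not a Khoj type).
interface ParsedQuery {
    command: string | null; // e.g. "online", or null when no prefix is present
    text: string;           // the query with any prefix removed
}

// Hypothetical helper: split an optional leading slash command off a query.
function parseChatQuery(query: string): ParsedQuery {
    const trimmed = query.trim();
    // A slash, a command word, whitespace, then the rest of the query.
    const match = trimmed.match(/^\/(\w+)\s+([\s\S]*)$/);
    if (!match) return { command: null, text: trimmed };
    return { command: match[1], text: match[2].trim() };
}

console.log(parseChatQuery("/online Deadline for filing taxes 2024."));
console.log(parseChatQuery("Where can I find the best pizza in New York City?"));
```

A query without the prefix comes back with `command` set to `null`, which corresponds to the default case where Khoj infers the sources on its own.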
## Self-Hosting
The general online search function currently requires an API key from Serper.dev. You can grab one here: https://serper.dev/, and then add it as an environment variable with the name `SERPER_DEV_API_KEY`.
Online search works out of the box even when self-hosting. Khoj uses [JinaAI's reader API](https://jina.ai/reader/) to search online and read webpages by default. No API key setup is necessary.
Without any API keys, Khoj will use the `requests` library to directly read any webpages you give it a link to. This means that you can use Khoj to read any webpage that you have access to on your local network.
To improve online search, set the `SERPER_DEV_API_KEY` environment variable to your [Serper.dev](https://serper.dev/) API key. These search results include additional context, such as the answer box and knowledge graph.
For advanced webpage reading, set the `OLOSTEP_API_KEY` environment variable to your [Olostep](https://www.olostep.com/) API key. This has a higher success rate at reading webpages than the default webpage reader.
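The precedence described in the updated self-hosting notes (JinaAI by default, Serper.dev for search when `SERPER_DEV_API_KEY` is set, Olostep for webpage reading when `OLOSTEP_API_KEY` is set) can be summarised in a few lines. This is a sketch under assumptions, not the Khoj server's code: the function name `chooseOnlineSearchSetup` and the provider labels are hypothetical, and only the two environment variable names come from the docs above.

```typescript
// Hypothetical labels for the providers named in the docs.
type SearchProvider = "serper" | "jina";
type WebpageReader = "olostep" | "jina";

interface OnlineSearchSetup {
    search: SearchProvider;
    reader: WebpageReader;
}

// Hypothetical helper: pick providers from a map of environment variables.
// An unset or empty key falls back to the default JinaAI reader API.
function chooseOnlineSearchSetup(env: Record<string, string | undefined>): OnlineSearchSetup {
    return {
        search: env["SERPER_DEV_API_KEY"] ? "serper" : "jina",
        reader: env["OLOSTEP_API_KEY"] ? "olostep" : "jina",
    };
}

console.log(chooseOnlineSearchSetup({}));
console.log(chooseOnlineSearchSetup({ SERPER_DEV_API_KEY: "example-key" }));
```

The two settings are independent in this sketch, so a Serper.dev key alone changes the search provider and leaves the webpage reader on its default.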


@ -1,51 +0,0 @@
---
sidebar_position: 2
---
# Demos
Check out a couple of demos and screenshots of Khoj in action.
### Screenshots
| Web | Obsidian | Emacs |
|:---:|:--------:|:-----:|
| ![](/img/khoj_search_on_web.png ':size=300px') | ![](/img/khoj_search_on_obsidian.png ':size=300px') | ![](/img/khoj_search_on_emacs.png ':size=300px') |
| ![](/img/khoj_chat_on_web.png ':size=300px') | ![](/img/khoj_chat_on_obsidian.png ':size=300px') | ![](/img/khoj_chat_on_emacs.png ':size=400px') |
### Videos
#### Khoj in Obsidian
[Link to Video](https://github-production-user-asset-6210df.s3.amazonaws.com/6413477/240061700-3e33d8ea-25bb-46c8-a3bf-c92f78d0f56b.mp4)
##### Installation
1. Install Khoj via `pip` and start Khoj backend in a terminal (Run `khoj`)
```bash
python -m pip install khoj-assistant
khoj
```
2. Install Khoj plugin via Community Plugins settings pane on Obsidian app
- Check the new Khoj plugin settings
- Let Khoj backend index the markdown, pdf, Github markdown files in the current Vault
- Open Khoj plugin on Obsidian via Search button on Left Pane
- Search \"*Announce plugin to folks*\" in the [Obsidian Plugin docs](https://marcus.se.net/obsidian-plugin-docs/)
- Jump to the [search result](https://marcus.se.net/obsidian-plugin-docs/publishing/submit-your-plugin)
#### Khoj in Emacs, Browser
[Link to Video](https://user-images.githubusercontent.com/6413477/184735169-92c78bf1-d827-4663-9087-a1ea194b8f4b.mp4)
##### Installation
- Install Khoj via pip
- Start Khoj app
- Add this readme and [khoj.el readme](https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs) as org-mode for Khoj to index
- Search \"*Setup editor*\" on the Web and Emacs. Re-rank the results for better accuracy
- Top result is what we are looking for, the [section to Install Khoj.el on Emacs](https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs#2-Install-Khojel)
##### Analysis
- The results do not have any words used in the query
- *Based on the top result it seems the re-ranking model understands that Emacs is an editor?*
- The results incrementally update as the query is entered
- The results are re-ranked, for better accuracy, once user hits enter


@ -27,10 +27,10 @@ keywords: ["khoj", "khoj ai", "khoj docs", "khoj documentation", "khoj features"
Welcome to the Khoj Docs! This is the best place to get set up and explore Khoj's features.
- Khoj is an open source, personal AI
- You can [chat](/features/chat) with it about anything. It'll use files you shared with it to respond, when relevant
- You can [chat](/features/chat) with it about anything. It'll use files you shared with it to respond, when relevant. It can also access information from the public internet.
- Quickly [find](/features/search) relevant notes and documents using natural language
- It understands pdf, plaintext, markdown, org-mode files, [notion pages](/data-sources/notion_integration) and [github repositories](/data-sources/github_integration)
- Access it from your [Emacs](/clients/emacs), [Obsidian](/clients/obsidian), [Web browser](/clients/web) or the [Khoj Desktop app](/clients/desktop)
- Access it from your [Emacs](/clients/emacs), [Obsidian](/clients/obsidian), the [Khoj desktop app](/clients/desktop), or [any web browser](/clients/web)
- Use [cloud](https://app.khoj.dev/login) to access your Khoj anytime from anywhere, [self-host](/get-started/setup) on consumer hardware for privacy
## Quickstart
@ -39,13 +39,3 @@ Welcome to the Khoj Docs! This is the best place to get setup and explore Khoj's
## At a Glance
![demo_chat](https://assets.khoj.dev/using_khoj_for_studying.gif)
#### [Search](/features/search)
- **Natural**: Use natural language queries to quickly find relevant notes and documents.
- **Incremental**: Incremental search for a fast, search-as-you-type experience
#### [Chat](/features/chat)
- **Faster answers**: Find answers faster, smoother than search. No need to manually scan through your notes to find answers.
- **Iterative discovery**: Iteratively explore and (re-)discover your notes
- **Assisted creativity**: Smoothly weave across answers retrieval and content generation
- **Online or Offline**: Choose online or offline chat depending on your requirements


@ -22,6 +22,8 @@ Self-hosting isn't for everyone, so we've still taken steps to make Khoj privacy
1. Your embeddings and the associated raw text are stored in a secure Postgres DB in our private AWS cloud. Your data is sharded on a unique user ID. We store the raw text in your files to improve file syncing and provide context when you chat with Khoj.
1. When you use the single-sign-on option with Google, we only receive your name, a link to your profile photo, and your email address.
You can see our full privacy policy [here](https://khoj.dev/privacy-policy).
:::tip[Info]
Your data is yours. We do not sell your data or use it for training models. Khoj is a sustainable, open-source alternative to closed-source, commercial personal AI. We have no interest in selling your data to make a quick buck.


@ -210,7 +210,7 @@ Add a `ServerChatSettings` with `Default` and `Summarizer` fields set to your pr
##### Configure OpenAI Chat
:::info[Ollama Integration]
Using Ollama? See the [Ollama Integration](/advanced/use-openai-proxy#ollama) section for more custom setup instructions.
Using Ollama? See the [Ollama Integration](/advanced/ollama) section for more custom setup instructions.
:::
1. Go to the [OpenAI settings](http://localhost:42110/server/admin/database/openaiprocessorconversationconfig/) in the server admin settings to add an OpenAI processor conversation config. This is where you set your API key and server API base URL. The API base URL is optional - it's only relevant if you're using another OpenAI-compatible proxy server.
@ -227,11 +227,9 @@ Any chat model on Huggingface in GGUF format can be used for local chat. Here's
- The `tokenizer` and `max-prompt-size` fields are optional. You can set these for non-standard models (i.e., not Mistral or Llama based models) or when you know the token limit of the model to improve context stuffing.
#### Share your data
You can sync your files and folders with Khoj using the [Desktop](/get-started/setup#2-download-the-desktop-client), Obsidian, or Emacs clients or just drag and drop specific files on the Web client. Here's how you can do it:
1. Select files and folders to index [using the desktop client]. When you click 'Save', the files will be sent to your server for indexing.
- Select Notion workspaces and Github repositories to index using the web interface.
You can sync your files and folders with Khoj using the [Desktop](/clients/desktop#setup), [Obsidian](/clients/obsidian#setup), or [Emacs](/clients/emacs#setup) clients or just drag and drop specific files on the [website](/clients/web#upload-documents). You can also directly sync your [Notion workspace](/data-sources/notion_integration).
[^1]: Khoj, by default, can use [OpenAI GPT3.5+ chat models](https://platform.openai.com/docs/models/overview) or [GGUF chat models](https://huggingface.co/models?library=gguf). See [this section](/miscellaneous/advanced#use-openai-compatible-llm-api-server-self-hosting) on how to locally use OpenAI-format compatible proxy servers.
[^1]: Khoj, by default, can use [OpenAI GPT3.5+ chat models](https://platform.openai.com/docs/models/overview) or [GGUF chat models](https://huggingface.co/models?library=gguf). See [this section](/advanced/use-openai-proxy) on how to locally use OpenAI-format compatible proxy servers.
### 3. Use Khoj 🚀


@ -1,7 +1,7 @@
{
"id": "khoj",
"name": "Khoj",
"version": "1.15.0",
"version": "1.16.0",
"minAppVersion": "0.15.0",
"description": "An AI copilot for your Second Brain",
"author": "Khoj Inc.",


@ -52,7 +52,8 @@ dependencies = [
"pyyaml ~= 6.0",
"rich >= 13.3.1",
"schedule == 1.1.0",
"sentence-transformers == 2.5.1",
"sentence-transformers == 3.0.1",
"einops == 0.8.0",
"transformers >= 4.28.0",
"torch == 2.2.2",
"uvicorn == 0.17.6",


@ -18,6 +18,7 @@ do
# Bump Obsidian plugin to current version
cd $project_root/src/interface/obsidian
yarn build # verify build before bumping version
yarn version --$version_type --no-git-tag-version
# append current version, min Obsidian app version from manifest to versions json
cp $project_root/versions.json .


@ -19,7 +19,7 @@ const textFileTypes = [
'org', 'md', 'markdown', 'txt', 'html', 'xml',
// Other valid text file extensions from https://google.github.io/magika/model/config.json
'appleplist', 'asm', 'asp', 'batch', 'c', 'cs', 'css', 'csv', 'eml', 'go', 'html', 'ini', 'internetshortcut', 'java', 'javascript', 'json', 'latex', 'lisp', 'makefile', 'markdown', 'mht', 'mum', 'pem', 'perl', 'php', 'powershell', 'python', 'rdf', 'rst', 'rtf', 'ruby', 'rust', 'scala', 'shell', 'smali', 'sql', 'svg', 'symlinktext', 'txt', 'vba', 'winregistry', 'xml', 'yaml']
const binaryFileTypes = ['pdf']
const binaryFileTypes = ['pdf', 'jpg', 'jpeg', 'png']
const validFileTypes = textFileTypes.concat(binaryFileTypes);
const schema = {
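With `jpg`, `jpeg` and `png` added to `binaryFileTypes`, image files now pass an extension check against `validFileTypes`. How the desktop app actually consumes the list is not shown in this hunk, so the following is a sketch under assumptions: the helper `isSyncable` is hypothetical, and the text type list is abbreviated.

```typescript
// Abbreviated copy of the lists from the hunk above.
const textFileTypes: string[] = ["org", "md", "markdown", "txt", "html", "xml"];
const binaryFileTypes: string[] = ["pdf", "jpg", "jpeg", "png"];
const validFileTypes: string[] = textFileTypes.concat(binaryFileTypes);

// Hypothetical helper: true when the file's extension is in validFileTypes.
// Names with no extension, or only a leading dot, are treated as not syncable.
function isSyncable(fileName: string): boolean {
    const dot = fileName.lastIndexOf(".");
    if (dot <= 0 || dot === fileName.length - 1) return false;
    const extension = fileName.slice(dot + 1).toLowerCase();
    return validFileTypes.indexOf(extension) !== -1;
}

console.log(isSyncable("diagram.png"));
console.log(isSyncable("archive.zip"));
```

Lower-casing the extension is a design choice of this sketch, so that `photo.PNG` is treated the same as `photo.png`.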


@ -1,6 +1,6 @@
{
"name": "Khoj",
"version": "1.15.0",
"version": "1.16.0",
"description": "An AI copilot for your Second Brain",
"author": "Saba Imran, Debanjum Singh Solanky <team@khoj.dev>",
"license": "GPL-3.0-or-later",


@ -6,7 +6,7 @@
;; Saba Imran <saba@khoj.dev>
;; Description: An AI copilot for your Second Brain
;; Keywords: search, chat, org-mode, outlines, markdown, pdf, image
;; Version: 1.15.0
;; Version: 1.16.0
;; Package-Requires: ((emacs "27.1") (transient "0.3.0") (dash "2.19.1"))
;; URL: https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs


@ -1,7 +1,7 @@
{
"id": "khoj",
"name": "Khoj",
"version": "1.15.0",
"version": "1.16.0",
"minAppVersion": "0.15.0",
"description": "An AI copilot for your Second Brain",
"author": "Khoj Inc.",


@ -1,6 +1,6 @@
{
"name": "Khoj",
"version": "1.15.0",
"version": "1.16.0",
"description": "An AI copilot for your Second Brain",
"author": "Debanjum Singh Solanky, Saba Imran <team@khoj.dev>",
"license": "GPL-3.0-or-later",


@ -1,8 +1,9 @@
import { ItemView, MarkdownRenderer, WorkspaceLeaf, request, requestUrl, setIcon } from 'obsidian';
import { ItemView, MarkdownRenderer, Scope, WorkspaceLeaf, request, requestUrl, setIcon } from 'obsidian';
import * as DOMPurify from 'dompurify';
import { KhojSetting } from 'src/settings';
import { KhojPaneView } from 'src/pane_view';
import { KhojView, createCopyParentText, getLinkToEntry, pasteTextAtCursor } from 'src/utils';
import { KhojSearchModal } from './search_modal';
export interface ChatJsonResult {
image?: string;
@ -24,10 +25,18 @@ export class KhojChatView extends KhojPaneView {
setting: KhojSetting;
waitingForLocation: boolean;
location: Location;
keyPressTimeout: NodeJS.Timeout | null = null;
constructor(leaf: WorkspaceLeaf, setting: KhojSetting) {
super(leaf, setting);
// Register chat view keybindings
this.scope = new Scope(this.app.scope);
this.scope.register(["Ctrl"], 'n', (_) => this.createNewConversation());
this.scope.register(["Ctrl"], 'o', async (_) => await this.toggleChatSessions());
this.scope.register(["Ctrl"], 'f', (_) => new KhojSearchModal(this.app, this.setting).open());
this.scope.register(["Ctrl"], 'r', (_) => new KhojSearchModal(this.app, this.setting, true).open());
this.waitingForLocation = true;
fetch("https://ipapi.co/json")
@ -61,8 +70,7 @@ export class KhojChatView extends KhojPaneView {
return "message-circle";
}
async chat() {
async chat(isVoice: boolean = false) {
// Get text in chat input element
let input_el = <HTMLTextAreaElement>this.contentEl.getElementsByClassName("khoj-chat-input")[0];
@ -72,7 +80,7 @@ export class KhojChatView extends KhojPaneView {
this.autoResize();
// Get and render chat response to user message
await this.getChatResponse(user_message);
await this.getChatResponse(user_message, isVoice);
}
async onOpen() {
@ -92,8 +100,9 @@ export class KhojChatView extends KhojPaneView {
const objectSrc = `object-src 'none';`;
const csp = `${defaultSrc} ${scriptSrc} ${connectSrc} ${styleSrc} ${imgSrc} ${childSrc} ${objectSrc}`;
// Add CSP meta tag to the Khoj Chat modal
document.head.createEl("meta", { attr: { "http-equiv": "Content-Security-Policy", "content": `${csp}` } });
// WARNING: CSP DISABLED for now as it breaks other Obsidian plugins. Re-enable once the CSP can be scoped to only the Khoj plugin.
// CSP meta tag for the Khoj Chat modal
// document.head.createEl("meta", { attr: { "http-equiv": "Content-Security-Policy", "content": `${csp}` } });
// Create area for chat logs
let chatBodyEl = contentEl.createDiv({ attr: { id: "khoj-chat-body", class: "khoj-chat-body" } });
@ -104,9 +113,10 @@ export class KhojChatView extends KhojPaneView {
text: "Chat Sessions",
attr: {
class: "khoj-input-row-button clickable-icon",
title: "Show Conversations (^O)",
},
})
chatSessions.addEventListener('click', async (_) => { await this.toggleChatSessions(chatBodyEl) });
chatSessions.addEventListener('click', async (_) => { await this.toggleChatSessions() });
setIcon(chatSessions, "history");
let chatInput = inputRow.createEl("textarea", {
@ -119,14 +129,20 @@ export class KhojChatView extends KhojPaneView {
chatInput.addEventListener('input', (_) => { this.onChatInput() });
chatInput.addEventListener('keydown', (event) => { this.incrementalChat(event) });
// Add event listeners for long press keybinding
this.contentEl.addEventListener('keydown', this.handleKeyDown.bind(this));
this.contentEl.addEventListener('keyup', this.handleKeyUp.bind(this));
let transcribe = inputRow.createEl("button", {
text: "Transcribe",
attr: {
id: "khoj-transcribe",
class: "khoj-transcribe khoj-input-row-button clickable-icon ",
title: "Start Voice Chat (^S)",
},
})
transcribe.addEventListener('mousedown', async (event) => { await this.speechToText(event) });
transcribe.addEventListener('mousedown', (event) => { this.startSpeechToText(event) });
transcribe.addEventListener('mouseup', async (event) => { await this.stopSpeechToText(event) });
transcribe.addEventListener('touchstart', async (event) => { await this.speechToText(event) });
transcribe.addEventListener('touchend', async (event) => { await this.speechToText(event) });
transcribe.addEventListener('touchcancel', async (event) => { await this.speechToText(event) });
@ -160,6 +176,46 @@ export class KhojChatView extends KhojPaneView {
});
}
startSpeechToText(event: KeyboardEvent | MouseEvent | TouchEvent, timeout=200) {
if (!this.keyPressTimeout) {
this.keyPressTimeout = setTimeout(async () => {
// Reset auto send voice message timer, UI if running
if (this.sendMessageTimeout) {
// Stop the auto send voice message countdown timer UI
clearTimeout(this.sendMessageTimeout);
const sendButton = <HTMLButtonElement>this.contentEl.getElementsByClassName("khoj-chat-send")[0]
setIcon(sendButton, "arrow-up-circle")
let sendImg = <SVGElement>sendButton.getElementsByClassName("lucide-arrow-up-circle")[0]
sendImg.addEventListener('click', async (_) => { await this.chat() });
// Reset chat input value
const chatInput = <HTMLTextAreaElement>this.contentEl.getElementsByClassName("khoj-chat-input")[0];
chatInput.value = "";
}
// Start new voice message
await this.speechToText(event);
}, timeout);
}
}
async stopSpeechToText(event: KeyboardEvent | MouseEvent | TouchEvent) {
if (this.mediaRecorder) {
await this.speechToText(event);
}
if (this.keyPressTimeout) {
clearTimeout(this.keyPressTimeout);
this.keyPressTimeout = null;
}
}
handleKeyDown(event: KeyboardEvent) {
// Start speech to text if keyboard shortcut is pressed
if (event.key === 's' && event.getModifierState('Control')) this.startSpeechToText(event);
}
async handleKeyUp(event: KeyboardEvent) {
// Stop speech to text if keyboard shortcut is released
if (event.key === 's' && event.getModifierState('Control')) await this.stopSpeechToText(event);
}
processOnlineReferences(referenceSection: HTMLElement, onlineContext: any) {
let numOnlineReferences = 0;
for (let subquery in onlineContext) {
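The long-press handling added in this hunk (`startSpeechToText`, `stopSpeechToText`, `handleKeyDown`, `handleKeyUp`) follows a common pattern: start a short timer on press, begin the action only if the key or button is still held when the timer fires, and cancel the timer on release. A simplified, standalone sketch of that pattern follows. It is not the plugin's code: the `LongPress` class, its callbacks and the injectable `schedule` parameter are hypothetical, and unlike the plugin it tracks its own `active` flag instead of checking `mediaRecorder`.

```typescript
type Cancel = () => void;
type Schedule = (fn: () => void, ms: number) => Cancel;

// Default scheduler backed by setTimeout. It can be swapped out in tests.
const timerSchedule: Schedule = (fn, ms) => {
    const id = setTimeout(fn, ms);
    return () => clearTimeout(id);
};

// Hypothetical helper class illustrating long-press detection.
class LongPress {
    private cancelPending: Cancel | null = null;
    private active = false;
    private onStart: () => void;
    private onStop: () => void;
    private holdMs: number;
    private schedule: Schedule;

    constructor(onStart: () => void, onStop: () => void, holdMs: number = 200, schedule: Schedule = timerSchedule) {
        this.onStart = onStart;
        this.onStop = onStop;
        this.holdMs = holdMs;
        this.schedule = schedule;
    }

    // Call on keydown / mousedown. Repeated calls while held are ignored,
    // which absorbs keyboard auto-repeat.
    press(): void {
        if (this.cancelPending || this.active) return;
        this.cancelPending = this.schedule(() => {
            this.cancelPending = null;
            this.active = true;
            this.onStart();
        }, this.holdMs);
    }

    // Call on keyup / mouseup. A release before holdMs cancels the start.
    release(): void {
        if (this.cancelPending) {
            this.cancelPending();
            this.cancelPending = null;
        }
        if (this.active) {
            this.active = false;
            this.onStop();
        }
    }
}
```

In a view, `press()` would be wired to `keydown` and `mousedown`, and `release()` to `keyup` and `mouseup`, with the recording start and stop passed as the two callbacks. The 200 ms default mirrors the `timeout=200` parameter in the diff.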
@ -294,6 +350,57 @@ export class KhojChatView extends KhojPaneView {
return referenceButton;
}
textToSpeech(message: string, event: MouseEvent | null = null): void {
// Replace the speaker with a loading icon.
let loader = document.createElement("span");
loader.classList.add("loader");
let speechButton: HTMLButtonElement;
let speechIcon: Element;
if (event === null) {
// Pick the last speech button if none is provided
let speechButtons = document.getElementsByClassName("speech-button");
speechButton = speechButtons[speechButtons.length - 1] as HTMLButtonElement;
let speechIcons = document.getElementsByClassName("speech-icon");
speechIcon = speechIcons[speechIcons.length - 1];
} else {
speechButton = event.currentTarget as HTMLButtonElement;
speechIcon = event.target as Element;
}
speechButton.appendChild(loader);
speechButton.disabled = true;
const context = new AudioContext();
let textToSpeechApi = `${this.setting.khojUrl}/api/chat/speech?text=${encodeURIComponent(message)}`;
fetch(textToSpeechApi, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
"Authorization": `Bearer ${this.setting.khojApiKey}`,
},
})
.then(response => response.arrayBuffer())
.then(arrayBuffer => context.decodeAudioData(arrayBuffer))
.then(audioBuffer => {
const source = context.createBufferSource();
source.buffer = audioBuffer;
source.connect(context.destination);
source.start(0);
source.onended = function() {
speechButton.removeChild(loader);
speechButton.disabled = false;
};
})
.catch(err => {
console.error("Error playing speech:", err);
speechButton.removeChild(loader);
speechButton.disabled = false; // Re-enable the button to allow retrying
});
}
formatHTMLMessage(message: string, raw = false, willReplace = true) {
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for some AI chat model.
message = message.replace(/<s>\[INST\].+(<\/s>)?/g, '');
@ -461,19 +568,36 @@ export class KhojChatView extends KhojPaneView {
renderActionButtons(message: string, chat_message_body_text_el: HTMLElement) {
let copyButton = this.contentEl.createEl('button');
copyButton.classList.add("copy-button");
copyButton.classList.add("chat-action-button");
copyButton.title = "Copy Message to Clipboard";
setIcon(copyButton, "copy-plus");
copyButton.addEventListener('click', createCopyParentText(message));
chat_message_body_text_el.append(copyButton);
// Add button to paste into current buffer
let pasteToFile = this.contentEl.createEl('button');
pasteToFile.classList.add("copy-button");
pasteToFile.classList.add("chat-action-button");
pasteToFile.title = "Paste Message to File";
setIcon(pasteToFile, "clipboard-paste");
pasteToFile.addEventListener('click', (event) => { pasteTextAtCursor(createCopyParentText(message, 'clipboard-paste')(event)); });
chat_message_body_text_el.append(pasteToFile);
// Only enable the speech feature if the user is subscribed
let speechButton = null;
if (this.setting.userInfo?.is_active) {
// Create a speech button icon to play the message out loud
speechButton = this.contentEl.createEl('button');
speechButton.classList.add("chat-action-button", "speech-button");
speechButton.title = "Listen to Message";
setIcon(speechButton, "speech")
speechButton.addEventListener('click', (event) => this.textToSpeech(message, event));
}
// Append buttons to parent element
chat_message_body_text_el.append(copyButton, pasteToFile);
if (speechButton) {
chat_message_body_text_el.append(speechButton);
}
}
formatDate(date: Date): string {
@ -483,14 +607,16 @@ export class KhojChatView extends KhojPaneView {
return `${time_string}, ${date_string}`;
}
createNewConversation(chatBodyEl: HTMLElement) {
createNewConversation() {
let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0] as HTMLElement;
chatBodyEl.innerHTML = "";
chatBodyEl.dataset.conversationId = "";
chatBodyEl.dataset.conversationTitle = "";
this.renderMessage(chatBodyEl, "Hey 👋🏾, what's up?", "khoj");
}
async toggleChatSessions(chatBodyEl: HTMLElement, forceShow: boolean = false): Promise<boolean> {
async toggleChatSessions(forceShow: boolean = false): Promise<boolean> {
let chatBodyEl = this.contentEl.getElementsByClassName("khoj-chat-body")[0] as HTMLElement;
if (!forceShow && this.contentEl.getElementsByClassName("side-panel")?.length > 0) {
chatBodyEl.innerHTML = "";
return this.getChatHistory(chatBodyEl);
@ -504,9 +630,10 @@ export class KhojChatView extends KhojPaneView {
const newConversationButtonEl = newConversationEl.createEl("button");
newConversationButtonEl.classList.add("new-conversation-button");
newConversationButtonEl.classList.add("side-panel-button");
newConversationButtonEl.addEventListener('click', (_) => this.createNewConversation(chatBodyEl));
newConversationButtonEl.addEventListener('click', (_) => this.createNewConversation());
setIcon(newConversationButtonEl, "plus");
newConversationButtonEl.innerHTML += "New";
newConversationButtonEl.title = "New Conversation (^N)";
const existingConversationsEl = sidePanelEl.createDiv("existing-conversations");
const conversationListEl = existingConversationsEl.createDiv("conversation-list");
@ -666,7 +793,7 @@ export class KhojChatView extends KhojPaneView {
chatBodyEl.innerHTML = "";
chatBodyEl.dataset.conversationId = "";
chatBodyEl.dataset.conversationTitle = "";
this.toggleChatSessions(chatBodyEl, true);
this.toggleChatSessions(true);
})
.catch(err => {
return;
@ -727,7 +854,7 @@ export class KhojChatView extends KhojPaneView {
return true;
}
async readChatStream(response: Response, responseElement: HTMLDivElement): Promise<void> {
async readChatStream(response: Response, responseElement: HTMLDivElement, isVoice: boolean = false): Promise<void> {
// Exit if response body is empty
if (response.body == null) return;
@ -737,8 +864,12 @@ export class KhojChatView extends KhojPaneView {
while (true) {
const { value, done } = await reader.read();
// Break if the stream is done
if (done) break;
if (done) {
// Automatically respond with voice if the subscribed user has sent voice message
if (isVoice && this.setting.userInfo?.is_active) this.textToSpeech(this.result);
// Break if the stream is done
break;
}
let responseText = decoder.decode(value);
if (responseText.includes("### compiled references:")) {
@ -756,7 +887,7 @@ export class KhojChatView extends KhojPaneView {
}
}
async getChatResponse(query: string | undefined | null): Promise<void> {
async getChatResponse(query: string | undefined | null, isVoice: boolean = false): Promise<void> {
// Exit if query is empty
if (!query || query === "") return;
@ -835,7 +966,7 @@ export class KhojChatView extends KhojPaneView {
}
} else {
// Stream and render chat response
await this.readChatStream(response, responseElement);
await this.readChatStream(response, responseElement, isVoice);
}
} catch (err) {
console.log(`Khoj chat response failed with\n${err}`);
@ -883,7 +1014,7 @@ export class KhojChatView extends KhojPaneView {
sendMessageTimeout: NodeJS.Timeout | undefined;
mediaRecorder: MediaRecorder | undefined;
async speechToText(event: MouseEvent | TouchEvent) {
async speechToText(event: MouseEvent | TouchEvent | KeyboardEvent) {
event.preventDefault();
const transcribeButton = <HTMLButtonElement>this.contentEl.getElementsByClassName("khoj-transcribe")[0];
const chatInput = <HTMLTextAreaElement>this.contentEl.getElementsByClassName("khoj-chat-input")[0];
@ -916,9 +1047,19 @@ export class KhojChatView extends KhojPaneView {
});
// Parse response from Khoj backend
let noSpeechText: string[] = [
"Thanks for watching!",
"Thanks for watching.",
"Thank you for watching!",
"Thank you for watching.",
"You",
"Bye."
];
let noSpeech: boolean = false;
if (response.status === 200) {
console.log(response);
chatInput.value += response.json.text.trimStart();
noSpeech = noSpeechText.includes(response.json.text.trimStart());
if (!noSpeech) chatInput.value += response.json.text.trimStart();
this.autoResize();
} else if (response.status === 501) {
throw new Error("⛔️ Configure speech-to-text model on server.");
@ -928,8 +1069,8 @@ export class KhojChatView extends KhojPaneView {
throw new Error("⛔️ Failed to transcribe audio.");
}
// Don't auto-send empty messages
if (chatInput.value.length === 0) return;
// Don't auto-send empty messages or when no speech is detected
if (chatInput.value.length === 0 || noSpeech) return;
// Show stop auto-send button. It stops auto-send when clicked
setIcon(sendButton, "stop-circle");
@ -938,6 +1079,7 @@ export class KhojChatView extends KhojPaneView {
// Start the countdown timer UI
stopSendButtonImg.getElementsByTagName("circle")[0].style.animation = "countdown 3s linear 1 forwards";
stopSendButtonImg.getElementsByTagName("circle")[0].style.color = "var(--icon-color-active)";
// Auto send message after 3 seconds
this.sendMessageTimeout = setTimeout(() => {
@ -947,7 +1089,7 @@ export class KhojChatView extends KhojPaneView {
sendImg.addEventListener('click', async (_) => { await this.chat() });
// Send message
this.chat();
this.chat(true);
}, 3000);
};
@ -966,21 +1108,23 @@ export class KhojChatView extends KhojPaneView {
});
this.mediaRecorder.start();
setIcon(transcribeButton, "mic-off");
// setIcon(transcribeButton, "mic-off");
transcribeButton.classList.add("loading-encircle")
};
// Toggle recording
if (!this.mediaRecorder || this.mediaRecorder.state === 'inactive' || event.type === 'touchstart') {
if (!this.mediaRecorder || this.mediaRecorder.state === 'inactive' || event.type === 'touchstart' || event.type === 'mousedown' || event.type === 'keydown') {
navigator.mediaDevices
.getUserMedia({ audio: true })
?.then(handleRecording)
.catch((e) => {
this.flashStatusInChatInput("⛔️ Failed to access microphone");
});
} else if (this.mediaRecorder.state === 'recording' || event.type === 'touchend' || event.type === 'touchcancel') {
} else if (this.mediaRecorder?.state === 'recording' || event.type === 'touchend' || event.type === 'touchcancel' || event.type === 'mouseup' || event.type === 'keyup') {
this.mediaRecorder.stop();
this.mediaRecorder.stream.getTracks().forEach(track => track.stop());
this.mediaRecorder = undefined;
transcribeButton.classList.remove("loading-encircle");
setIcon(transcribeButton, "mic");
}
}
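The `noSpeechText` check added to `speechToText` drops a transcript that exactly matches one of a few filler phrases, which the diff treats as "no speech detected", so that it is neither appended to the chat input nor auto-sent. The same check can be isolated as a pure function. This is a sketch only: the helper `usableTranscript` is hypothetical and does not exist in the plugin, while the phrase list is copied from the diff.

```typescript
// Phrases copied from the `noSpeechText` list in the diff above.
const noSpeechText: string[] = [
    "Thanks for watching!",
    "Thanks for watching.",
    "Thank you for watching!",
    "Thank you for watching.",
    "You",
    "Bye.",
];

// Hypothetical helper: returns the text to append to the chat input, or null
// when the transcript is empty or exactly matches a known filler phrase.
function usableTranscript(text: string): string | null {
    const cleaned = text.replace(/^\s+/, ""); // strips leading whitespace, like trimStart()
    if (cleaned.length === 0) return null;
    return noSpeechText.indexOf(cleaned) !== -1 ? null : cleaned;
}

console.log(usableTranscript(" Thanks for watching!")); // null
console.log(usableTranscript(" What is on my calendar today?"));
```

As in the diff, the match is exact, so a longer sentence that merely contains one of these phrases is kept.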


@ -2,7 +2,8 @@ import { Plugin, WorkspaceLeaf } from 'obsidian';
import { KhojSetting, KhojSettingTab, DEFAULT_SETTINGS } from 'src/settings'
import { KhojSearchModal } from 'src/search_modal'
import { KhojChatView } from 'src/chat_view'
import { updateContentIndex, canConnectToBackend, KhojView } from './utils';
import { updateContentIndex, canConnectToBackend, KhojView, jumpToPreviousView } from './utils';
import { KhojPaneView } from './pane_view';
export default class Khoj extends Plugin {
@ -79,16 +80,30 @@ export default class Khoj extends Plugin {
const leaves = workspace.getLeavesOfType(viewType);
if (leaves.length > 0) {
// A leaf with our view already exists, use that
leaf = leaves[0];
// A leaf with our view already exists, use that
leaf = leaves[0];
} else {
// Our view could not be found in the workspace, create a new leaf
// in the right sidebar for it
leaf = workspace.getRightLeaf(false);
await leaf?.setViewState({ type: viewType, active: true });
// Our view could not be found in the workspace, create a new leaf
// in the right sidebar for it
leaf = workspace.getRightLeaf(false);
await leaf?.setViewState({ type: viewType, active: true });
}
// "Reveal" the leaf in case it is in a collapsed sidebar
if (leaf) workspace.revealLeaf(leaf);
}
if (leaf) {
const activeKhojLeaf = workspace.getActiveViewOfType(KhojPaneView)?.leaf;
// Jump to the previous view if the current view is Khoj Side Pane
if (activeKhojLeaf === leaf) jumpToPreviousView();
// Else Reveal the leaf in case it is in a collapsed sidebar
else {
workspace.revealLeaf(leaf);
if (viewType === KhojView.CHAT) {
// focus on the chat input when the chat view is opened
let chatView = leaf.view as KhojChatView;
let chatInput = <HTMLTextAreaElement>chatView.contentEl.getElementsByClassName("khoj-chat-input")[0];
if (chatInput) chatInput.focus();
}
}
}
}
}

View file

@ -38,16 +38,24 @@ export abstract class KhojPaneView extends ItemView {
const leaves = workspace.getLeavesOfType(viewType);
if (leaves.length > 0) {
// A leaf with our view already exists, use that
leaf = leaves[0];
// A leaf with our view already exists, use that
leaf = leaves[0];
} else {
// Our view could not be found in the workspace, create a new leaf
// in the right sidebar for it
leaf = workspace.getRightLeaf(false);
await leaf?.setViewState({ type: viewType, active: true });
// Our view could not be found in the workspace, create a new leaf
// in the right sidebar for it
leaf = workspace.getRightLeaf(false);
await leaf?.setViewState({ type: viewType, active: true });
}
// "Reveal" the leaf in case it is in a collapsed sidebar
if (leaf) workspace.revealLeaf(leaf);
}
if (leaf) {
if (viewType === KhojView.CHAT) {
// focus on the chat input when the chat view is opened
let chatInput = <HTMLTextAreaElement>this.contentEl.getElementsByClassName("khoj-chat-input")[0];
if (chatInput) chatInput.focus();
}
// "Reveal" the leaf in case it is in a collapsed sidebar
workspace.revealLeaf(leaf);
}
}
}

View file

@ -333,6 +333,12 @@ export function createCopyParentText(message: string, originalButton: string = '
}
}
export function jumpToPreviousView() {
const editor: Editor = this.app.workspace.getActiveFileView()?.editor
if (!editor) return;
editor.focus();
}
export function pasteTextAtCursor(text: string | undefined) {
// Get the current active file's editor
const editor: Editor = this.app.workspace.getActiveFileView()?.editor

View file

@ -477,7 +477,7 @@ span.khoj-nav-item-text {
}
/* Copy button */
button.copy-button {
button.chat-action-button {
display: block;
border-radius: 4px;
color: var(--text-muted);
@ -491,20 +491,54 @@ button.copy-button {
margin-top: 8px;
float: right;
}
button.copy-button span {
button.chat-action-button span {
cursor: pointer;
display: inline-block;
position: relative;
transition: 0.5s;
}
button.chat-action-button:hover {
background-color: var(--background-modifier-active-hover);
color: var(--text-normal);
}
img.copy-icon {
width: 16px;
height: 16px;
}
button.copy-button:hover {
background-color: var(--background-modifier-active-hover);
color: var(--text-normal);
/* Circular Loading Spinner */
.loader {
width: 18px;
height: 18px;
border: 3px solid #FFF;
border-radius: 50%;
display: inline-block;
position: relative;
box-sizing: border-box;
animation: rotation 1s linear infinite;
}
.loader::after {
content: '';
box-sizing: border-box;
position: absolute;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
width: 18px;
height: 18px;
border-radius: 50%;
border: 3px solid transparent;
border-bottom-color: var(--flower);
}
@keyframes rotation {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}
/* Loading Spinner */
@ -564,6 +598,44 @@ button.copy-button:hover {
}
}
/* Loading Encircle */
.loading-encircle {
position: relative;
}
.loading-encircle::before {
content: '';
position: absolute;
top: 50%;
left: 50%;
width: 24px;
height: 24px;
margin-top: -16px;
margin-left: -16px;
border: 4px solid transparent;
border-color: var(--icon-color-active);
border-radius: 50%;
animation: pulse 3s ease-in-out infinite;
}
@keyframes pulse {
0% {
transform: scale(1);
opacity: 1;
}
50% {
transform: scale(1.2);
opacity: 0.2;
}
100% {
transform: scale(1);
opacity: 1;
}
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
@media only screen and (max-width: 600px) {
div.khoj-header {
display: grid;

View file

@ -52,5 +52,6 @@
"1.12.1": "0.15.0",
"1.13.0": "0.15.0",
"1.14.0": "0.15.0",
"1.15.0": "0.15.0"
"1.15.0": "0.15.0",
"1.16.0": "0.15.0"
}

View file

@ -112,7 +112,7 @@ ASGI_APPLICATION = "app.asgi.application"
# Database
# https://docs.djangoproject.com/en/4.2/ref/settings/#databases
DATA_UPLOAD_MAX_NUMBER_FIELDS = 20000
DATABASES = {
"default": {
"ENGINE": "django.db.backends.postgresql",
@ -122,6 +122,7 @@ DATABASES = {
"NAME": os.getenv("POSTGRES_DB", "khoj"),
"PASSWORD": os.getenv("POSTGRES_PASSWORD", "postgres"),
"CONN_MAX_AGE": 0,
"CONN_HEALTH_CHECKS": True,
}
}

View file

@ -48,6 +48,7 @@ from khoj.database.models import (
UserConversationConfig,
UserRequests,
UserSearchModelConfig,
UserTextToImageModelConfig,
UserVoiceModelConfig,
VoiceModelOption,
)
@ -907,7 +908,45 @@ class ConversationAdapters:
@staticmethod
async def aget_text_to_image_model_config():
return await TextToImageModelConfig.objects.filter().afirst()
return await TextToImageModelConfig.objects.filter().prefetch_related("openai_config").afirst()
@staticmethod
def get_text_to_image_model_config():
return TextToImageModelConfig.objects.filter().first()
@staticmethod
def get_text_to_image_model_options():
return TextToImageModelConfig.objects.all()
@staticmethod
def get_user_text_to_image_model_config(user: KhojUser):
config = UserTextToImageModelConfig.objects.filter(user=user).first()
if not config:
default_config = ConversationAdapters.get_text_to_image_model_config()
if not default_config:
return None
return default_config
return config.setting
@staticmethod
async def aget_user_text_to_image_model(user: KhojUser) -> Optional[TextToImageModelConfig]:
config = await UserTextToImageModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if not config:
default_config = await ConversationAdapters.aget_text_to_image_model_config()
if not default_config:
return None
return default_config
return config.setting
@staticmethod
async def aset_user_text_to_image_model(user: KhojUser, text_to_image_model_config_id: int):
config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst()
if not config:
return None
new_config, _ = await UserTextToImageModelConfig.objects.aupdate_or_create(
user=user, defaults={"setting": config}
)
return new_config
@staticmethod
def add_files_to_filter(user: KhojUser, conversation_id: int, files: List[str]):
@ -949,7 +988,7 @@ class FileObjectAdapters:
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
@staticmethod
def get_file_objects_by_name(user: KhojUser, file_name: str):
def get_file_object_by_name(user: KhojUser, file_name: str):
return FileObject.objects.filter(user=user, file_name=file_name).first()
@staticmethod
@ -1005,27 +1044,39 @@ class EntryAdapters:
return deleted_count
@staticmethod
def delete_all_entries_by_type(user: KhojUser, file_type: str = None):
if file_type is None:
deleted_count, _ = Entry.objects.filter(user=user).delete()
else:
deleted_count, _ = Entry.objects.filter(user=user, file_type=file_type).delete()
def get_filtered_entries(user: KhojUser, file_type: str = None, file_source: str = None):
queryset = Entry.objects.filter(user=user)
if file_type is not None:
queryset = queryset.filter(file_type=file_type)
if file_source is not None:
queryset = queryset.filter(file_source=file_source)
return queryset
@staticmethod
def delete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
deleted_count = 0
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
while queryset.exists():
batch_ids = list(queryset.values_list("id", flat=True)[:batch_size])
batch = Entry.objects.filter(id__in=batch_ids, user=user)
count, _ = batch.delete()
deleted_count += count
return deleted_count
@staticmethod
def delete_all_entries(user: KhojUser, file_source: str = None):
if file_source is None:
deleted_count, _ = Entry.objects.filter(user=user).delete()
else:
deleted_count, _ = Entry.objects.filter(user=user, file_source=file_source).delete()
async def adelete_all_entries(user: KhojUser, file_type: str = None, file_source: str = None, batch_size=1000):
deleted_count = 0
queryset = EntryAdapters.get_filtered_entries(user, file_type, file_source)
while await queryset.aexists():
batch_ids = await sync_to_async(list)(queryset.values_list("id", flat=True)[:batch_size])
batch = Entry.objects.filter(id__in=batch_ids, user=user)
count, _ = await batch.adelete()
deleted_count += count
return deleted_count
@staticmethod
async def adelete_all_entries(user: KhojUser, file_source: str = None):
if file_source is None:
return await Entry.objects.filter(user=user).adelete()
return await Entry.objects.filter(user=user, file_source=file_source).adelete()
@staticmethod
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
return Entry.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
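
Note on the batched deletion in `delete_all_entries` / `adelete_all_entries` above: the loop repeatedly fetches up to `batch_size` ids that still match the filter, deletes them, and adds up the counts. Below is a minimal, framework-free sketch of that pattern. It does not use the Django ORM; `delete_in_batches`, `fetch_ids`, `delete_ids`, and the in-memory `store` are hypothetical names introduced only for illustration, and the zero-progress guard is an addition that is not in the commit.

```python
from typing import Callable, List, Set


def delete_in_batches(
    fetch_ids: Callable[[int], List[int]],
    delete_ids: Callable[[List[int]], int],
    batch_size: int = 1000,
) -> int:
    """Delete matching rows one batch at a time; return the total deleted."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    deleted_count = 0
    while True:
        # Fetch at most batch_size ids that still match the filter
        batch_ids = fetch_ids(batch_size)
        if not batch_ids:
            break
        deleted = delete_ids(batch_ids)
        if deleted == 0:
            # Assumption: stop rather than loop forever if nothing was deleted
            break
        deleted_count += deleted
    return deleted_count


# Hypothetical in-memory stand-in for the Entry table
store: Set[int] = set(range(2500))


def fetch_ids(limit: int) -> List[int]:
    return sorted(store)[:limit]


def delete_ids(ids: List[int]) -> int:
    before = len(store)
    store.difference_update(ids)
    return before - len(store)


total = delete_in_batches(fetch_ids, delete_ids, batch_size=1000)
print(total)
```

Deleting in bounded batches keeps each delete statement small, which is presumably the motivation for the `batch_size=1000` default in the commit.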

View file

@ -96,7 +96,6 @@ admin.site.register(SpeechToTextModelOptions)
admin.site.register(SearchModelConfig)
admin.site.register(ReflectiveQuestion)
admin.site.register(UserSearchModelConfig)
admin.site.register(TextToImageModelConfig)
admin.site.register(ClientApplication)
admin.site.register(GithubConfig)
admin.site.register(NotionConfig)
@ -126,7 +125,10 @@ class EntryAdmin(admin.ModelAdmin):
"file_path",
)
search_fields = ("id", "user__email", "user__username", "file_path")
list_filter = ("file_type",)
list_filter = (
"file_type",
"user__email",
)
ordering = ("-created_at",)
@ -153,6 +155,16 @@ class ChatModelOptionsAdmin(admin.ModelAdmin):
search_fields = ("id", "chat_model", "model_type")
@admin.register(TextToImageModelConfig)
class TextToImageModelOptionsAdmin(admin.ModelAdmin):
list_display = (
"id",
"model_name",
"model_type",
)
search_fields = ("id", "model_name", "model_type")
@admin.register(OpenAIProcessorConversationConfig)
class OpenAIProcessorConversationConfigAdmin(admin.ModelAdmin):
list_display = (

View file

@ -0,0 +1,58 @@
# Generated by Django 4.2.11 on 2024-06-26 03:27
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0048_voicemodeloption_uservoicemodelconfig"),
]
operations = [
migrations.AddField(
model_name="texttoimagemodelconfig",
name="api_key",
field=models.CharField(blank=True, default=None, max_length=200, null=True),
),
migrations.AddField(
model_name="texttoimagemodelconfig",
name="openai_config",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="database.openaiprocessorconversationconfig",
),
),
migrations.AlterField(
model_name="texttoimagemodelconfig",
name="model_type",
field=models.CharField(
choices=[("openai", "Openai"), ("stability-ai", "Stabilityai")], default="openai", max_length=200
),
),
migrations.CreateModel(
name="UserTextToImageModelConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"setting",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="database.texttoimagemodelconfig"
),
),
(
"user",
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
),
],
options={
"abstract": False,
},
),
]

View file

@ -0,0 +1,14 @@
# Generated by Django 4.2.11 on 2024-07-02 12:20
from typing import List
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("database", "0049_texttoimagemodelconfig_api_key_and_more"),
("database", "0050_alter_processlock_name"),
]
operations: List[str] = []

View file

@ -215,11 +215,11 @@ class SearchModelConfig(BaseModel):
# Bi-encoder model of sentence-transformer type to load from HuggingFace
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
# Config passed to the sentence-transformer model constructor. E.g. device="cuda:0", trust_remote_code=True etc.
bi_encoder_model_config = models.JSONField(default=dict)
bi_encoder_model_config = models.JSONField(default=dict, blank=True)
# Query encode configs like prompt, precision, normalize_embeddings, etc. for sentence-transformer models
bi_encoder_query_encode_config = models.JSONField(default=dict)
bi_encoder_query_encode_config = models.JSONField(default=dict, blank=True)
# Docs encode configs like prompt, precision, normalize_embeddings, etc. for sentence-transformer models
bi_encoder_docs_encode_config = models.JSONField(default=dict)
bi_encoder_docs_encode_config = models.JSONField(default=dict, blank=True)
# Cross-encoder model of sentence-transformer type to load from HuggingFace
cross_encoder = models.CharField(max_length=200, default="mixedbread-ai/mxbai-rerank-xsmall-v1")
# Inference server API endpoint to use for embeddings inference. Bi-encoder model should be hosted on this server
@ -235,9 +235,37 @@ class SearchModelConfig(BaseModel):
class TextToImageModelConfig(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
STABILITYAI = "stability-ai"
model_name = models.CharField(max_length=200, default="dall-e-3")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
openai_config = models.ForeignKey(
OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True
)
def clean(self):
# Custom validation logic
error = {}
if self.model_type == self.ModelType.OPENAI:
if self.api_key and self.openai_config:
error[
"api_key"
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
error[
"openai_config"
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
if self.model_type != self.ModelType.OPENAI:
if not self.api_key:
error["api_key"] = "The API key field must be set for non OpenAI models."
if self.openai_config:
error["openai_config"] = "OpenAI config cannot be set for non OpenAI models."
if error:
raise ValidationError(error)
def save(self, *args, **kwargs):
self.clean()
super().save(*args, **kwargs)
class SpeechToTextModelOptions(BaseModel):
@ -264,6 +292,11 @@ class UserSearchModelConfig(BaseModel):
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
class UserTextToImageModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
class Conversation(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict)
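
Note on `TextToImageModelConfig.clean()` above: the validation amounts to a small rule table keyed on `model_type`. The following sketch restates those rules as a plain function so they can be checked without a database. `validate_text_to_image_config` and its `has_openai_config` flag are hypothetical names for illustration, and the error messages are paraphrased rather than copied from the model.

```python
from typing import Dict, Optional


def validate_text_to_image_config(
    model_type: str, api_key: Optional[str], has_openai_config: bool
) -> Dict[str, str]:
    """Return a field -> message map; an empty map means the config is valid."""
    errors: Dict[str, str] = {}
    if model_type == "openai":
        # OpenAI models may use an API key or a linked OpenAI config, but not both
        if api_key and has_openai_config:
            message = "Set either an API key or an OpenAI config, not both."
            errors["api_key"] = message
            errors["openai_config"] = message
    else:
        # Non-OpenAI models need their own API key and no OpenAI config
        if not api_key:
            errors["api_key"] = "An API key is required for non-OpenAI models."
        if has_openai_config:
            errors["openai_config"] = "OpenAI config cannot be set for non-OpenAI models."
    return errors


print(validate_text_to_image_config("openai", "sk-example", True))
print(validate_text_to_image_config("stability-ai", None, False))
```

As in the model code, an OpenAI entry with neither an API key nor an OpenAI config passes validation; only setting both is rejected.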

View file

@ -242,18 +242,25 @@
<script>
async function openChat(agentSlug) {
// Create a loading animation
let loading = document.createElement("div");
loading.innerHTML = '<div>Booting your agent...</div><span class="loader"></span>';
loading.style.position = "fixed";
loading.style.top = "0";
loading.style.right = "0";
loading.style.bottom = "0";
loading.style.left = "0";
loading.style.display = "flex";
loading.style.justifyContent = "center";
loading.style.alignItems = "center";
loading.style.backgroundColor = "rgba(0, 0, 0, 0.5)"; // Semi-transparent black
document.body.appendChild(loading);
let loadingTextEl = document.createElement("div");
loadingTextEl.textContent = 'Booting your agent...';
let loadingAnimationEl = document.createElement("span");
loadingAnimationEl.className = "loader";
let loadingEl = document.createElement("div");
loadingEl.style.position = "fixed";
loadingEl.style.top = "0";
loadingEl.style.right = "0";
loadingEl.style.bottom = "0";
loadingEl.style.left = "0";
loadingEl.style.display = "flex";
loadingEl.style.justifyContent = "center";
loadingEl.style.alignItems = "center";
loadingEl.style.backgroundColor = "rgba(0, 0, 0, 0.5)"; // Semi-transparent black
loadingEl.append(loadingTextEl, loadingAnimationEl);
document.body.appendChild(loadingEl);
let response = await fetch(`/api/chat/sessions?agent_slug=${agentSlug}`, { method: "POST" });
let data = await response.json();

View file

@ -5,13 +5,22 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0">
<link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png?v={{ khoj_version }}">
<title>Khoj</title>
<meta http-equiv="Content-Security-Policy"
content="default-src 'self' https://assets.khoj.dev;
script-src 'self' https://assets.khoj.dev 'unsafe-inline';
connect-src 'self' https://ipapi.co/json;
style-src 'self' https://assets.khoj.dev 'unsafe-inline' https://fonts.googleapis.com;
img-src 'self' data: https://*.khoj.dev https://*.googleusercontent.com;
font-src https://assets.khoj.dev https://fonts.gstatic.com;
child-src 'none';
object-src 'none';">
<link rel="stylesheet" href="/static/assets/pico.min.css?v={{ khoj_version }}">
<link rel="stylesheet" href="/static/assets/khoj.css?v={{ khoj_version }}">
<script
integrity="sha384-05IkdNHoAlkhrFVUCCN805WC/h4mcI98GUBssmShF2VJAXKyZTrO/TmJ+4eBo0Cy"
crossorigin="anonymous"
src="https://cdnjs.cloudflare.com/ajax/libs/intl-tel-input/17.0.13/js/intlTelInput.min.js"></script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/intl-tel-input/17.0.13/css/intlTelInput.css">
src="https://assets.khoj.dev/intl-tel-input/intlTelInput.min.js"></script>
<link rel="stylesheet" href="https://assets.khoj.dev/intl-tel-input/intlTelInput.css">
</head>
<script type="text/javascript" src="/static/assets/utils.js?v={{ khoj_version }}"></script>
<script type="text/javascript" src="/static/assets/purify.min.js?v={{ khoj_version }}"></script>
@ -332,6 +341,7 @@
margin: 20px;
}
select#paint-models,
select#search-models,
select#voice-models,
select#chat-models {

View file

@ -48,8 +48,8 @@ Get the Khoj [Desktop](https://khoj.dev/downloads), [Obsidian](https://docs.khoj
To get started, just start typing below. You can also type / to see a list of commands.
`.trim()
const allowedExtensions = ['text/org', 'text/markdown', 'text/plain', 'text/html', 'application/pdf', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'];
const allowedFileEndings = ['org', 'md', 'txt', 'html', 'pdf', 'docx'];
const allowedExtensions = ['text/org', 'text/markdown', 'text/plain', 'text/html', 'application/pdf', 'image/jpeg', 'image/png', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'];
const allowedFileEndings = ['org', 'md', 'txt', 'html', 'pdf', 'jpg', 'jpeg', 'png', 'docx'];
let chatOptions = [];
function createCopyParentText(message) {
return function(event) {
@ -149,7 +149,6 @@ To get started, just start typing below. You can also type / to see a list of co
}
function generateOnlineReference(reference, index) {
// Generate HTML for Chat Reference
let title = reference.title || reference.link;
let link = reference.link;
@ -170,7 +169,7 @@ To get started, just start typing below. You can also type / to see a list of co
linkElement.textContent = title;
let referenceButton = document.createElement('button');
referenceButton.innerHTML = linkElement.outerHTML;
referenceButton.appendChild(linkElement);
referenceButton.id = `ref-${index}`;
referenceButton.classList.add("reference-button");
referenceButton.classList.add("collapsed");
@ -181,11 +180,12 @@ To get started, just start typing below. You can also type / to see a list of co
if (this.classList.contains("collapsed")) {
this.classList.remove("collapsed");
this.classList.add("expanded");
this.innerHTML = linkElement.outerHTML + `<br><br>${question + snippet}`;
this.innerHTML = `${linkElement.outerHTML}<br><br>${question}${snippet}`;
} else {
this.classList.add("collapsed");
this.classList.remove("expanded");
this.innerHTML = linkElement.outerHTML;
this.innerHTML = "";
this.appendChild(linkElement);
}
});
@ -578,7 +578,7 @@ To get started, just start typing below. You can also type / to see a list of co
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
referenceExpandButton.innerHTML = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.textContent = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) {
@ -888,7 +888,7 @@ To get started, just start typing below. You can also type / to see a list of co
if (overlayText == null) {
dropzone.classList.add('dragover');
var overlayText = document.createElement("div");
overlayText.innerHTML = "Select file(s) or drag + drop it here to share it with Khoj";
overlayText.textContent = "Select file(s) or drag + drop it here to share it with Khoj";
overlayText.className = "dropzone-overlay";
overlayText.id = "dropzone-overlay";
dropzone.appendChild(overlayText);
@ -949,7 +949,7 @@ To get started, just start typing below. You can also type / to see a list of co
if (overlayText != null) {
// Display loading spinner
var loadingSpinner = document.createElement("div");
overlayText.innerHTML = "Uploading file(s) for indexing";
overlayText.textContent = "Uploading file(s) for indexing";
loadingSpinner.className = "spinner";
overlayText.appendChild(loadingSpinner);
}
@ -974,7 +974,12 @@ To get started, just start typing below. You can also type / to see a list of co
fileType = "text/html";
} else if (fileExtension === "pdf") {
fileType = "application/pdf";
} else {
} else if (fileExtension === "jpg" || fileExtension === "jpeg"){
fileType = "image/jpeg";
} else if (fileExtension === "png") {
fileType = "image/png";
}
else {
// Skip this file if its type is not supported
resolve();
return;
@ -1037,7 +1042,7 @@ To get started, just start typing below. You can also type / to see a list of co
if (overlayText == null) {
var overlayText = document.createElement("div");
overlayText.innerHTML = "Drop file to share it with Khoj";
overlayText.textContent = "Drop file to share it with Khoj";
overlayText.className = "dropzone-overlay";
overlayText.id = "dropzone-overlay";
this.appendChild(overlayText);
@ -1174,11 +1179,15 @@ To get started, just start typing below. You can also type / to see a list of co
websocket.onclose = function(event) {
websocket = null;
console.log("WebSocket is closed now.");
let setupWebSocketButton = document.createElement("button");
setupWebSocketButton.textContent = "Reconnect to Server";
setupWebSocketButton.onclick = setupWebSocket;
let statusDotIcon = document.getElementById("connection-status-icon");
statusDotIcon.style.backgroundColor = "red";
let statusDotText = document.getElementById("connection-status-text");
statusDotText.innerHTML = "";
statusDotText.style.marginTop = "5px";
statusDotText.innerHTML = '<button onclick="setupWebSocket()">Reconnect to Server</button>';
statusDotText.appendChild(setupWebSocketButton);
}
websocket.onerror = function(event) {
console.log("WebSocket error observed:", event);
@ -1429,7 +1438,7 @@ To get started, just start typing below. You can also type / to see a list of co
questionStarterSuggestions.innerHTML = "";
data.forEach((questionStarter) => {
let questionStarterButton = document.createElement('button');
questionStarterButton.innerHTML = questionStarter;
questionStarterButton.textContent = questionStarter;
questionStarterButton.classList.add("question-starter");
questionStarterButton.addEventListener('click', function() {
questionStarterSuggestions.style.display = "none";
@ -1601,7 +1610,7 @@ To get started, just start typing below. You can also type / to see a list of co
let closeButton = document.createElement('button');
closeButton.id = "close-button";
closeButton.innerHTML = "Close";
closeButton.textContent = "Close";
closeButton.classList.add("close-button");
closeButton.addEventListener('click', function() {
modal.remove();
@ -1655,7 +1664,7 @@ To get started, just start typing below. You can also type / to see a list of co
let threeDotMenu = document.createElement('div');
threeDotMenu.classList.add("three-dot-menu");
let threeDotMenuButton = document.createElement('button');
threeDotMenuButton.innerHTML = "⋮";
threeDotMenuButton.textContent = "⋮";
threeDotMenuButton.classList.add("three-dot-menu-button");
threeDotMenuButton.addEventListener('click', function(event) {
event.stopPropagation();
@ -1674,7 +1683,7 @@ To get started, just start typing below. You can also type / to see a list of co
conversationMenu.classList.add("conversation-menu");
let editTitleButton = document.createElement('button');
editTitleButton.innerHTML = "Rename";
editTitleButton.textContent = "Rename";
editTitleButton.classList.add("edit-title-button");
editTitleButton.classList.add("three-dot-menu-button-item");
editTitleButton.addEventListener('click', function(event) {
@ -1708,7 +1717,7 @@ To get started, just start typing below. You can also type / to see a list of co
conversationTitleInputBox.appendChild(conversationTitleInput);
let conversationTitleInputButton = document.createElement('button');
conversationTitleInputButton.innerHTML = "Save";
conversationTitleInputButton.textContent = "Save";
conversationTitleInputButton.classList.add("three-dot-menu-button-item");
conversationTitleInputButton.addEventListener('click', function(event) {
event.stopPropagation();
@ -1732,7 +1741,7 @@ To get started, just start typing below. You can also type / to see a list of co
threeDotMenu.appendChild(conversationMenu);
let shareButton = document.createElement('button');
shareButton.innerHTML = "Share";
shareButton.textContent = "Share";
shareButton.type = "button";
shareButton.classList.add("share-conversation-button");
shareButton.classList.add("three-dot-menu-button-item");
@ -1799,7 +1808,7 @@ To get started, just start typing below. You can also type / to see a list of co
let deleteButton = document.createElement('button');
deleteButton.type = "button";
deleteButton.innerHTML = "Delete";
deleteButton.textContent = "Delete";
deleteButton.classList.add("delete-conversation-button");
deleteButton.classList.add("three-dot-menu-button-item");
deleteButton.addEventListener('click', function(event) {
@ -1963,12 +1972,16 @@ To get started, just start typing below. You can also type / to see a list of co
}
allFiles = data;
var nofilesmessage = document.getElementsByClassName("no-files-message")[0];
nofilesmessage.innerHTML = "";
if(allFiles.length === 0){
nofilesmessage.innerHTML = `<a class="inline-chat-link" href="https://docs.khoj.dev/category/clients/">How to upload files</a>`;
let inlineChatLinkEl = document.createElement('a');
inlineChatLinkEl.className = "inline-chat-link";
inlineChatLinkEl.href = "https://docs.khoj.dev/category/clients/";
inlineChatLinkEl.textContent = "How to upload files";
nofilesmessage.appendChild(inlineChatLinkEl);
document.getElementsByClassName("file-toggle-button")[0].style.display = "none";
}
else{
nofilesmessage.innerHTML = "";
document.getElementsByClassName("file-toggle-button")[0].style.display = "block";
}
})

View file

@ -163,10 +163,6 @@
<div class="section-cards">
<div class="finalize-buttons">
<button id="sync" type="submit" title="Regenerate index from scratch for Notion, GitHub configuration" style="display: flex; justify-content: center;">
<img class="card-icon" src="/static/assets/icons/sync.svg" alt="Sync">
<h3 class="card-title">
Sync
</h3>
</button>
</div>
</div>
@ -192,11 +188,37 @@
</div>
<div class="card-action-row">
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
<button id="save-model" class="card-button happy" onclick="updateChatModel()">
<button id="save-chat-model" class="card-button happy" onclick="updateChatModel()">
Save
</button>
{% else %}
<button id="save-model" class="card-button" disabled>
<button id="save-chat-model" class="card-button" disabled>
Subscribe to use different models
</button>
{% endif %}
</div>
</div>
<div class="card">
<div class="card-title-row">
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
<h3 class="card-title">
<span>Paint</span>
</h3>
</div>
<div class="card-description-row">
<select id="paint-models">
{% for option in paint_model_options %}
<option value="{{ option.id }}" {% if option.id == selected_paint_model_config %}selected{% endif %}>{{ option.model_name }}</option>
{% endfor %}
</select>
</div>
<div class="card-action-row">
{% if (not billing_enabled) or (subscription_state != 'unsubscribed' and subscription_state != 'expired') %}
<button id="save-paint-model" class="card-button happy" onclick="updatePaintModel()">
Save
</button>
{% else %}
<button id="save-paint-model" class="card-button" disabled>
Subscribe to use different models
</button>
{% endif %}
@ -382,7 +404,8 @@
.then(data => {
if (data.status == "ok") {
let notificationBanner = document.getElementById("notification-banner");
notificationBanner.innerHTML = "Profile name has been updated!";
notificationBanner.innerHTML = "";
notificationBanner.textContent = "Profile name has been updated!";
notificationBanner.style.display = "block";
setTimeout(function() {
notificationBanner.style.display = "none";
@ -394,8 +417,9 @@
function updateVoiceModel() {
const voiceModel = document.getElementById("voice-models").value;
const saveVoiceModelButton = document.getElementById("save-voice-model");
saveVoiceModelButton.innerHTML = "";
saveVoiceModelButton.disabled = true;
saveVoiceModelButton.innerHTML = "Saving...";
saveVoiceModelButton.textContent = "Saving...";
fetch('/api/config/data/voice/model?id=' + voiceModel, {
method: 'POST',
@ -406,18 +430,19 @@
.then(response => response.json())
.then(data => {
if (data.status == "ok") {
saveVoiceModelButton.innerHTML = "Save";
saveVoiceModelButton.textContent = "Save";
saveVoiceModelButton.disabled = false;
let notificationBanner = document.getElementById("notification-banner");
notificationBanner.innerHTML = "Voice model has been updated!";
notificationBanner.innerHTML = "";
notificationBanner.textContent = "Voice model has been updated!";
notificationBanner.style.display = "block";
setTimeout(function() {
notificationBanner.style.display = "none";
}, 5000);
} else {
saveVoiceModelButton.innerHTML = "Error";
saveVoiceModelButton.textContent = "Error";
saveVoiceModelButton.disabled = false;
}
})
@ -425,9 +450,10 @@
function updateChatModel() {
const chatModel = document.getElementById("chat-models").value;
const saveModelButton = document.getElementById("save-model");
const saveModelButton = document.getElementById("save-chat-model");
saveModelButton.disabled = true;
saveModelButton.innerHTML = "Saving...";
saveModelButton.innerHTML = "";
saveModelButton.textContent = "Saving...";
fetch('/api/config/data/conversation/model?id=' + chatModel, {
method: 'POST',
@ -438,18 +464,19 @@
.then(response => response.json())
.then(data => {
if (data.status == "ok") {
saveModelButton.innerHTML = "Save";
saveModelButton.textContent = "Save";
saveModelButton.disabled = false;
let notificationBanner = document.getElementById("notification-banner");
notificationBanner.innerHTML = "Conversation model has been updated!";
notificationBanner.innerHTML = "";
notificationBanner.textContent = "Conversation model has been updated!";
notificationBanner.style.display = "block";
setTimeout(function() {
notificationBanner.style.display = "none";
}, 5000);
} else {
saveModelButton.innerHTML = "Error";
saveModelButton.textContent = "Error";
saveModelButton.disabled = false;
}
})
@ -463,8 +490,9 @@
const searchModel = document.getElementById("search-models").value;
const saveSearchModelButton = document.getElementById("save-search-model");
saveSearchModelButton.innerHTML = "";
saveSearchModelButton.disabled = true;
saveSearchModelButton.innerHTML = "Saving...";
saveSearchModelButton.textContent = "Saving...";
fetch('/api/config/data/search/model?id=' + searchModel, {
method: 'POST',
@ -475,15 +503,16 @@
.then(response => response.json())
.then(data => {
if (data.status == "ok") {
saveSearchModelButton.innerHTML = "Save";
saveSearchModelButton.textContent = "Save";
saveSearchModelButton.disabled = false;
} else {
saveSearchModelButton.innerHTML = "Error";
saveSearchModelButton.textContent = "Error";
saveSearchModelButton.disabled = false;
}
let notificationBanner = document.getElementById("notification-banner");
notificationBanner.innerHTML = "Khoj can now better understand the language of your content! Manually sync your data from one of the Khoj clients to update your knowledge base.";
notificationBanner.innerHTML = "";
notificationBanner.textContent = "Khoj can now better understand the language of your content! Manually sync your data from one of the Khoj clients to update your knowledge base.";
notificationBanner.style.display = "block";
setTimeout(function() {
notificationBanner.style.display = "none";
@ -491,6 +520,38 @@
})
};
function updatePaintModel() {
const paintModel = document.getElementById("paint-models").value;
const saveModelButton = document.getElementById("save-paint-model");
saveModelButton.disabled = true;
saveModelButton.innerHTML = "Saving...";
fetch('/api/config/data/paint/model?id=' + paintModel, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
}
})
.then(response => response.json())
.then(data => {
if (data.status == "ok") {
saveModelButton.innerHTML = "Save";
saveModelButton.disabled = false;
let notificationBanner = document.getElementById("notification-banner");
notificationBanner.innerHTML = "Paint model has been updated!";
notificationBanner.style.display = "block";
setTimeout(function() {
notificationBanner.style.display = "none";
}, 5000);
} else {
saveModelButton.innerHTML = "Error";
saveModelButton.disabled = false;
}
})
};
function clearContentType(content_source) {
fetch('/api/config/data/content-source/' + content_source, {
method: 'DELETE',
@ -549,23 +610,38 @@
})
}
var sync = document.getElementById("sync");
sync.addEventListener("click", function(event) {
function populateSyncButton() {
let syncIconEl = document.createElement("img");
syncIconEl.className = "card-icon";
syncIconEl.src = "/static/assets/icons/sync.svg";
syncIconEl.alt = "Sync";
let syncButtonTitleEl = document.createElement("h3");
syncButtonTitleEl.className = "card-title";
syncButtonTitleEl.textContent = "Sync";
return [syncButtonTitleEl, syncIconEl];
}
var syncButtonEl = document.getElementById("sync");
syncButtonEl.innerHTML = "";
syncButtonEl.append(...populateSyncButton());
syncButtonEl.addEventListener("click", function(event) {
event.preventDefault();
updateIndex(
force=true,
successText="Synced!",
errorText="Unable to sync. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
button=sync,
button=syncButtonEl,
loadingText="Syncing...",
emoji="");
});
function updateIndex(force, successText, errorText, button, loadingText, emoji) {
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
const original_html = button.innerHTML;
button.disabled = true;
button.innerHTML = emoji + " " + loadingText;
button.innerHTML = ""
button.textContent = emoji + " " + loadingText;
fetch('/api/update?&client=web&force=' + force, {
method: 'GET',
headers: {
@ -582,19 +658,19 @@
document.getElementById("status").style.display = "none";
button.disabled = false;
button.innerHTML = `✅ ${successText}`;
button.textContent = `✅ ${successText}`;
setTimeout(function() {
button.innerHTML = original_html;
button.replaceChildren(...populateSyncButton());
}, 2000);
})
.catch((error) => {
console.error('Error:', error);
document.getElementById("status").innerHTML = emoji + " " + errorText
document.getElementById("status").textContent = emoji + " " + errorText
document.getElementById("status").style.display = "block";
button.disabled = false;
button.innerHTML = '⚠️ Unsuccessful';
button.textContent = '⚠️ Unsuccessful';
setTimeout(function() {
button.innerHTML = original_html;
button.replaceChildren(...populateSyncButton());
}, 2000);
});
@ -629,7 +705,7 @@
})
.then(response => response.json())
.then(tokenObj => {
apiKeyList.innerHTML += generateTokenRow(tokenObj);
apiKeyList.appendChild(generateTokenRow(tokenObj));
});
}
@ -638,16 +714,16 @@
navigator.clipboard.writeText(token);
// Flash the API key copied icon
const apiKeyColumn = document.getElementById(`api-key-${token}`);
const original_html = apiKeyColumn.innerHTML;
const original_text = apiKeyColumn.textContent;
const copyApiKeyButton = document.getElementById(`api-key-copy-${token}`);
setTimeout(function() {
copyApiKeyButton.src = "/static/assets/icons/copy-button-success.svg";
setTimeout(() => {
copyApiKeyButton.src = "/static/assets/icons/copy-button.svg";
}, 1000);
apiKeyColumn.innerHTML = "✅ Copied!";
apiKeyColumn.textContent = "✅ Copied!";
setTimeout(function() {
apiKeyColumn.innerHTML = original_html;
apiKeyColumn.textContent = original_text;
}, 1000);
}, 100);
}
@ -670,16 +746,50 @@
let tokenName = tokenObj.name;
let truncatedToken = token.slice(0, 4) + "..." + token.slice(-4);
let tokenId = `${tokenName}-${truncatedToken}`;
return `
<tr id="api-key-item-${token}">
<td><b>${tokenName}</b></td>
<td id="api-key-${token}">${truncatedToken}</td>
<td>
<img id="api-key-copy-${token}" onclick="copyAPIKey('${token}')" class="configured-icon api-key-action enabled" src="/static/assets/icons/copy-button.svg" alt="Copy API Key" title="Copy API Key">
<img id="api-key-delete-${token}" onclick="deleteAPIKey('${token}')" class="configured-icon api-key-action enabled" src="/static/assets/icons/delete.svg" alt="Delete API Key" title="Delete API Key">
</td>
</tr>
`;
// Create API Key Row
let apiKeyItemEl = document.createElement("tr");
apiKeyItemEl.id = `api-key-item-${token}`;
// API Key Name Row
let apiKeyNameEl = document.createElement("td");
let apiKeyNameTextEl = document.createElement("b");
apiKeyNameTextEl.textContent = tokenName;
// API Key Token Row
let apiKeyTokenEl = document.createElement("td");
apiKeyTokenEl.id = `api-key-${token}`;
apiKeyTokenEl.textContent = truncatedToken;
// API Key Actions Row
let apiKeyActionsEl = document.createElement("td");
// Copy API Key Button
let copyApiKeyButtonEl = document.createElement("img");
copyApiKeyButtonEl.id = `api-key-copy-${token}`;
copyApiKeyButtonEl.className = "configured-icon api-key-action enabled";
copyApiKeyButtonEl.src = "/static/assets/icons/copy-button.svg";
copyApiKeyButtonEl.alt = "Copy API Key";
copyApiKeyButtonEl.title = "Copy API Key";
copyApiKeyButtonEl.onclick = function() {
copyAPIKey(token);
};
// Delete API Key Button
let deleteApiKeyButtonEl = document.createElement("img");
deleteApiKeyButtonEl.id = `api-key-delete-${token}`;
deleteApiKeyButtonEl.className = "configured-icon api-key-action enabled";
deleteApiKeyButtonEl.src = "/static/assets/icons/delete.svg";
deleteApiKeyButtonEl.alt = "Delete API Key";
deleteApiKeyButtonEl.title = "Delete API Key";
deleteApiKeyButtonEl.onclick = function() {
deleteAPIKey(token);
};
// Construct the API Key Row
apiKeyNameEl.append(apiKeyNameTextEl);
apiKeyActionsEl.append(copyApiKeyButtonEl, deleteApiKeyButtonEl);
apiKeyItemEl.append(apiKeyNameEl, apiKeyTokenEl, apiKeyActionsEl);
return apiKeyItemEl;
}
function listApiKeys() {
@ -688,7 +798,7 @@
.then(response => response.json())
.then(tokens => {
if (!(tokens?.length > 0)) return;
apiKeyList.innerHTML = tokens?.map(generateTokenRow).join("");
apiKeyList.append(...tokens?.map(generateTokenRow));
});
}
@ -696,11 +806,11 @@
listApiKeys();
function getIndexedDataSize() {
document.getElementById("indexed-data-size").innerHTML = "Calculating...";
document.getElementById("indexed-data-size").textContent = "Calculating...";
fetch('/api/config/index/size')
.then(response => response.json())
.then(data => {
document.getElementById("indexed-data-size").innerHTML = data.indexed_data_size_in_mb + " MB used";
document.getElementById("indexed-data-size").textContent = data.indexed_data_size_in_mb + " MB used";
});
}
@ -729,7 +839,7 @@
.catch(() => callback("us"))
},
separateDialCode: true,
utilsScript: "https://cdn.jsdelivr.net/npm/intl-tel-input@18.2.1/build/js/utils.js",
utilsScript: "https://assets.khoj.dev/intl-tel-input/utils.js",
});
const errorMap = ["Invalid number", "Invalid country code", "Too short", "Too long", "Invalid number"];
@ -800,7 +910,7 @@
phonenumberVerifyButton.addEventListener("click", () => {
console.log(iti.getValidationError());
if (iti.isValidNumber() == false) {
phoneNumberUpdateCallback.innerHTML = "Invalid phone number: " + errorMap[iti.getValidationError()];
phoneNumberUpdateCallback.textContent = "Invalid phone number: " + errorMap[iti.getValidationError()];
phoneNumberUpdateCallback.style.display = "block";
setTimeout(function() {
phoneNumberUpdateCallback.style.display = "none";
@ -817,12 +927,12 @@
.then(data => {
if (data.status == "ok") {
if (isTwilioEnabled == "True" || isTwilioEnabled == "true") {
phoneNumberUpdateCallback.innerHTML = "OTP sent to your phone number";
phoneNumberUpdateCallback.textContent = "OTP sent to your phone number";
phonenumberVerifyOTPButton.style.display = "block";
phonenumberOTPInput.style.display = "block";
} else {
phonenumberVerifiedText.style.display = "block";
phoneNumberUpdateCallback.innerHTML = "Phone number updated";
phoneNumberUpdateCallback.textContent = "Phone number updated";
phonenumberUnverifiedText.style.display = "none";
}
phonenumberVerifyButton.style.display = "none";
@ -831,7 +941,7 @@
phoneNumberUpdateCallback.style.display = "none";
}, 5000);
} else {
phoneNumberUpdateCallback.innerHTML = "Error updating phone number";
phoneNumberUpdateCallback.textContent = "Error updating phone number";
phoneNumberUpdateCallback.style.display = "block";
setTimeout(function() {
phoneNumberUpdateCallback.style.display = "none";
@ -840,7 +950,7 @@
})
.catch((error) => {
console.error('Error:', error);
phoneNumberUpdateCallback.innerHTML = "Error updating phone number";
phoneNumberUpdateCallback.textContent = "Error updating phone number";
phoneNumberUpdateCallback.style.display = "block";
setTimeout(function() {
phoneNumberUpdateCallback.style.display = "none";
@ -852,7 +962,7 @@
phonenumberVerifyOTPButton.addEventListener("click", () => {
const otp = phonenumberOTPInput.value;
if (otp.length != 6) {
phoneNumberUpdateCallback.innerHTML = "Your OTP should be exactly 6 digits";
phoneNumberUpdateCallback.textContent = "Your OTP should be exactly 6 digits";
phoneNumberUpdateCallback.style.display = "block";
setTimeout(function() {
phoneNumberUpdateCallback.style.display = "none";
@ -869,7 +979,7 @@
.then(response => response.json())
.then(data => {
if (data.status == "ok") {
phoneNumberUpdateCallback.innerHTML = "Phone number updated";
phoneNumberUpdateCallback.textContent = "Phone number updated";
phonenumberVerifiedText.style.display = "block";
phonenumberUnverifiedText.style.display = "none";
phoneNumberUpdateCallback.style.display = "block";
@ -881,7 +991,7 @@
phoneNumberUpdateCallback.style.display = "none";
}, 5000);
} else {
phoneNumberUpdateCallback.innerHTML = "Error updating phone number";
phoneNumberUpdateCallback.textContent = "Error updating phone number";
phoneNumberUpdateCallback.style.display = "block";
setTimeout(function() {
phoneNumberUpdateCallback.style.display = "none";
@ -890,7 +1000,7 @@
})
.catch((error) => {
console.error('Error:', error);
phoneNumberUpdateCallback.innerHTML = "Error updating phone number";
phoneNumberUpdateCallback.textContent = "Error updating phone number";
phoneNumberUpdateCallback.style.display = "block";
setTimeout(function() {
phoneNumberUpdateCallback.style.display = "none";

View file

@ -12,7 +12,7 @@
</h2>
<div class="section-manage-files">
<div id="delete-all-files" class="delete-all-files">
<button id="delete-all-files" type="submit" title="Remove all computer files from Khoj">🗑️ Delete all</button>
<button id="delete-all-files-button" type="submit" title="Remove all computer files from Khoj">🗑️ Delete all</button>
</div>
<div class="indexed-files">
</div>
@ -56,7 +56,10 @@
if (data.length == 0) {
document.getElementById("delete-all-files").style.display = "none";
indexedFiles.innerHTML = "<div class='card-description'>No documents synced with Khoj</div>";
let noFilesElement = document.createElement("div");
noFilesElement.classList.add("card-description");
noFilesElement.textContent = "No documents synced with Khoj";
indexedFiles.appendChild(noFilesElement);
} else {
document.getElementById("get-desktop-client").style.display = "none";
document.getElementById("delete-all-files").style.display = "block";
@ -86,14 +89,14 @@
let fileNameElement = document.createElement("div");
fileNameElement.classList.add("content-name");
fileNameElement.innerHTML = filename;
fileNameElement.textContent = filename;
fileElement.appendChild(fileNameElement);
let buttonContainer = document.createElement("div");
buttonContainer.classList.add("remove-button-container");
let removeFileButton = document.createElement("button");
removeFileButton.classList.add("remove-file-button");
removeFileButton.innerHTML = "🗑️";
removeFileButton.textContent = "🗑️";
removeFileButton.addEventListener("click", ((filename) => {
return () => {
removeFile(filename);
@ -112,9 +115,13 @@
// Get all currently indexed files on page load
getAllComputerFilenames();
let deleteAllComputerFilesButton = document.getElementById("delete-all-files");
let deleteAllComputerFilesButton = document.getElementById("delete-all-files-button");
deleteAllComputerFilesButton.addEventListener("click", function(event) {
event.preventDefault();
const originalDeleteAllComputerFilesButtonText = deleteAllComputerFilesButton.textContent;
deleteAllComputerFilesButton.textContent = "🗑️ Deleting...";
deleteAllComputerFilesButton.disabled = true;
fetch('/api/config/data/content-source/computer', {
method: 'DELETE',
headers: {
@ -122,11 +129,11 @@
}
})
.then(response => response.json())
.then(data => {
if (data.status == "ok") {
getAllComputerFilenames();
}
})
.finally(() => {
getAllComputerFilenames();
deleteAllComputerFilesButton.textContent = originalDeleteAllComputerFilesButtonText;
deleteAllComputerFilesButton.disabled = false;
});
});
</script>
{% endblock %}

View file

@ -70,18 +70,50 @@
repo.classList.add("repo");
const id = Date.now();
repo.id = "repo-card-" + id;
repo.innerHTML = `
<label for="repo-owner">Repository Owner</label>
<input type="text" id="repo-owner" name="repo_owner">
<label for="repo-name">Repository Name</label>
<input type="text" id="repo-name" name="repo_name">
<label for="repo-branch">Repository Branch</label>
<input type="text" id="repo-branch" name="repo_branch">
<button type="button"
class="remove-repo-button"
onclick="remove_repo(${id})"
id="remove-repo-button-${id}">Remove Repository</button>
`;
// Create repo owner, name, branch elements
let repoOwnerLabel = document.createElement("label");
repoOwnerLabel.textContent = "Repository Owner";
repoOwnerLabel.htmlFor = "repo-owner-" + id;
let repoOwner = document.createElement("input");
repoOwner.type = "text";
repoOwner.id = "repo-owner-" + id;
repoOwner.name = "repo_owner";
let repoNameLabel = document.createElement("label");
repoNameLabel.textContent = "Repository Name";
repoNameLabel.htmlFor = "repo-name-" + id;
let repoName = document.createElement("input");
repoName.type = "text";
repoName.id = "repo-name-" + id;
repoName.name = "repo_name";
let repoBranchLabel = document.createElement("label");
repoBranchLabel.textContent = "Repository Branch";
repoBranchLabel.htmlFor = "repo-branch-" + id;
let repoBranch = document.createElement("input");
repoBranch.type = "text";
repoBranch.id = "repo-branch-" + id;
repoBranch.name = "repo_branch";
let removeRepoButton = document.createElement("button");
removeRepoButton.type = "button";
removeRepoButton.classList.add("remove-repo-button");
removeRepoButton.onclick = function() { remove_repo(id); };
removeRepoButton.id = "remove-repo-button-" + id;
removeRepoButton.textContent = "Remove Repository";
// Append elements to repo card
repo.append(
repoOwnerLabel, repoOwner,
repoNameLabel, repoName,
repoBranchLabel, repoBranch,
removeRepoButton
);
document.getElementById("repositories").appendChild(repo);
})
@ -95,7 +127,7 @@
const pat_token = document.getElementById("pat-token").value;
if (pat_token == "") {
document.getElementById("success").innerHTML = "❌ Please enter a Personal Access Token.";
document.getElementById("success").textContent = "❌ Please enter a Personal Access Token.";
document.getElementById("success").style.display = "block";
return;
}
@ -122,14 +154,14 @@
}
if (repos.length == 0) {
document.getElementById("success").innerHTML = "❌ Please add at least one repository.";
document.getElementById("success").textContent = "❌ Please add at least one repository.";
document.getElementById("success").style.display = "block";
return;
}
const submitButton = document.getElementById("submit");
submitButton.disabled = true;
submitButton.innerHTML = "Saving...";
submitButton.textContent = "Saving...";
// Save Github config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
@ -147,11 +179,11 @@
.then(response => response.json())
.then(data => data["status"] === "ok" ? data : Promise.reject(data))
.catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Github settings.";
document.getElementById("success").textContent = "⚠️ Failed to save Github settings.";
document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save settings";
submitButton.textContent = "⚠️ Failed to save settings";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.textContent = "Save";
submitButton.disabled = false;
}, 2000);
return;
@ -163,18 +195,18 @@
.then(data => data["status"] == "ok" ? data : Promise.reject(data))
.then(data => {
document.getElementById("success").style.display = "none";
submitButton.innerHTML = "✅ Successfully updated";
submitButton.textContent = "✅ Successfully updated";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.textContent = "Save";
submitButton.disabled = false;
}, 2000);
})
.catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Github content.";
document.getElementById("success").textContent = "⚠️ Failed to save Github content.";
document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save content";
submitButton.textContent = "⚠️ Failed to save content";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.textContent = "Save";
submitButton.disabled = false;
}, 2000);
});

View file

@ -34,14 +34,14 @@
const token = document.getElementById("token").value;
if (token == "") {
document.getElementById("success").innerHTML = "❌ Please enter a Notion Token.";
document.getElementById("success").textContent = "❌ Please enter a Notion Token.";
document.getElementById("success").style.display = "block";
return;
}
const submitButton = document.getElementById("submit");
submitButton.disabled = true;
submitButton.innerHTML = "Syncing...";
submitButton.textContent = "Syncing...";
// Save Notion config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
@ -58,11 +58,11 @@
.then(response => response.json())
.then(data => data["status"] === "ok" ? data : Promise.reject(data))
.catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Notion settings.";
document.getElementById("success").textContent = "⚠️ Failed to save Notion settings.";
document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save settings";
submitButton.textContent = "⚠️ Failed to save settings";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.textContent = "Save";
submitButton.disabled = false;
}, 2000);
return;
@ -74,18 +74,18 @@
.then(data => data["status"] == "ok" ? data : Promise.reject(data))
.then(data => {
document.getElementById("success").style.display = "none";
submitButton.innerHTML = "✅ Successfully updated";
submitButton.textContent = "✅ Successfully updated";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.textContent = "Save";
submitButton.disabled = false;
}, 2000);
})
.catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Notion content.";
document.getElementById("success").textContent = "⚠️ Failed to save Notion content.";
document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save content";
submitButton.textContent = "⚠️ Failed to save content";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.textContent = "Save";
submitButton.disabled = false;
}, 2000);
});

View file

@ -127,7 +127,7 @@ To get started, just start typing below. You can also type / to see a list of co
linkElement.textContent = title;
let referenceButton = document.createElement('button');
referenceButton.innerHTML = linkElement.outerHTML;
referenceButton.appendChild(linkElement);
referenceButton.id = `ref-${index}`;
referenceButton.classList.add("reference-button");
referenceButton.classList.add("collapsed");
@ -138,11 +138,12 @@ To get started, just start typing below. You can also type / to see a list of co
if (this.classList.contains("collapsed")) {
this.classList.remove("collapsed");
this.classList.add("expanded");
this.innerHTML = linkElement.outerHTML + `<br><br>${question + snippet}`;
this.innerHTML = `${linkElement.outerHTML}<br><br>${question + snippet}`;
} else {
this.classList.add("collapsed");
this.classList.remove("expanded");
this.innerHTML = linkElement.outerHTML;
this.innerHTML = "";
this.appendChild(linkElement);
}
});
@ -296,7 +297,7 @@ To get started, just start typing below. You can also type / to see a list of co
}
let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.innerHTML = expandButtonText;
referenceExpandButton.textContent = expandButtonText;
references.appendChild(referenceSection);
@ -447,7 +448,7 @@ To get started, just start typing below. You can also type / to see a list of co
let referenceExpandButton = document.createElement('button');
referenceExpandButton.classList.add("reference-expand-button");
referenceExpandButton.innerHTML = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.textContent = numReferences == 1 ? "1 reference" : `${numReferences} references`;
referenceExpandButton.addEventListener('click', function() {
if (referenceSection.classList.contains("collapsed")) {
@ -815,7 +816,7 @@ Learn more [here](https://khoj.dev).
let closeButton = document.createElement('button');
closeButton.id = "close-button";
closeButton.innerHTML = "Close";
closeButton.textContent = "Close";
closeButton.classList.add("close-button");
closeButton.addEventListener('click', function() {
modal.remove();

View file

@ -0,0 +1,118 @@
import base64
import logging
import os
from datetime import datetime
from typing import Dict, List, Tuple
from rapidocr_onnxruntime import RapidOCR
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import Entry
logger = logging.getLogger(__name__)
class ImageToEntries(TextToEntries):
def __init__(self):
super().__init__()
# Define Functions
def process(
self, files: dict[str, str] = None, full_corpus: bool = True, user: KhojUser = None, regenerate: bool = False
) -> Tuple[int, int]:
# Separate files marked for deletion (empty content) from files to process
if not full_corpus:
deletion_file_names = set([file for file in files if files[file] == b""])
files_to_process = set(files) - deletion_file_names
files = {file: files[file] for file in files_to_process}
else:
deletion_file_names = None
# Extract Entries from specified image files
with timer("Extract entries from specified Image files", logger):
file_to_text_map, current_entries = ImageToEntries.extract_image_entries(files)
# Split entries by max tokens supported by model
with timer("Split entries by max token size supported by model", logger):
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=256)
# Identify, mark and merge any new entries with previous entries
with timer("Identify new or updated entries", logger):
num_new_embeddings, num_deleted_embeddings = self.update_embeddings(
current_entries,
DbEntry.EntryType.IMAGE,
DbEntry.EntrySource.COMPUTER,
"compiled",
logger,
deletion_file_names,
user,
regenerate=regenerate,
file_to_text_map=file_to_text_map,
)
return num_new_embeddings, num_deleted_embeddings
@staticmethod
def extract_image_entries(image_files) -> Tuple[Dict, List[Entry]]: # important function
"""Extract text entries from specified image files"""
file_to_text_map = dict()
entries: List[str] = []
entry_to_location_map: List[Tuple[str, str]] = []
for image_file in image_files:
try:
loader = RapidOCR()
bytes = image_files[image_file]
# write the image to a temporary file
timestamp_now = datetime.utcnow().timestamp()
# use either png or jpg
if image_file.endswith(".png"):
tmp_file = f"tmp_image_file_{timestamp_now}.png"
elif image_file.endswith(".jpg") or image_file.endswith(".jpeg"):
tmp_file = f"tmp_image_file_{timestamp_now}.jpg"
with open(tmp_file, "wb") as f:
bytes = image_files[image_file]
f.write(bytes)
try:
image_entries_per_file = ""
result, _ = loader(tmp_file)
if result:
expanded_entries = [text[1] for text in result]
image_entries_per_file = " ".join(expanded_entries)
except ImportError:
logger.warning(f"Unable to process file: {image_file}. This file will not be indexed.")
continue
entry_to_location_map.append((image_entries_per_file, image_file))
entries.extend([image_entries_per_file])
file_to_text_map[image_file] = image_entries_per_file
except Exception as e:
logger.warning(f"Unable to process file: {image_file}. This file will not be indexed.")
logger.warning(e, exc_info=True)
finally:
if os.path.exists(tmp_file):
os.remove(tmp_file)
return file_to_text_map, ImageToEntries.convert_image_entries_to_maps(entries, dict(entry_to_location_map))
@staticmethod
def convert_image_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
"Convert each image entry into a dictionary"
entries = []
for parsed_entry in parsed_entries:
entry_filename = entry_to_file_map[parsed_entry]
# Append base filename to compiled entry for context to model
heading = f"{entry_filename}\n"
compiled_entry = f"{heading}{parsed_entry}"
entries.append(
Entry(
compiled=compiled_entry,
raw=parsed_entry,
heading=heading,
file=f"{entry_filename}",
)
)
logger.debug(f"Converted {len(parsed_entries)} image entries to dictionaries")
return entries
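
The `finally` block in `extract_image_entries` above reads `tmp_file` even on paths where no branch assigned it, for instance a file whose name ends in neither `.png` nor `.jpg`/`.jpeg`. Below is a minimal sketch of one way to sidestep that, assuming the OCR loader only needs a path on disk. `write_temp_image` and `ocr_with_cleanup` are hypothetical helper names, not part of Khoj, and `loader` is any callable standing in for the `RapidOCR()` instance. Unlike the original, this sketch matches extensions case-insensitively.

```python
import os
import tempfile
from typing import Callable, Optional

# Extensions handled by the original code (assumed here to be the full supported set)
SUPPORTED_SUFFIXES = (".png", ".jpg", ".jpeg")


def write_temp_image(filename: str, data: bytes) -> Optional[str]:
    """Write image bytes to a temp file; return its path, or None for unsupported extensions."""
    suffix = os.path.splitext(filename)[1].lower()
    if suffix not in SUPPORTED_SUFFIXES:
        return None
    # delete=False keeps the file on disk after the handle closes, so it can be reopened by path
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
        f.write(data)
        return f.name


def ocr_with_cleanup(filename: str, data: bytes, loader: Callable[[str], str]) -> str:
    """Run loader on a temp copy of the image and always remove the temp file afterwards."""
    tmp_file = write_temp_image(filename, data)
    if tmp_file is None:
        return ""
    try:
        return loader(tmp_file)
    finally:
        # tmp_file is guaranteed to be set here, so cleanup cannot raise a NameError
        os.remove(tmp_file)
```

Because the temp path is resolved before the `try` block is entered, the cleanup step never touches an unbound name, and `tempfile` picks a unique filename instead of relying on a timestamp.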

View file

@ -146,7 +146,7 @@ class MarkdownToEntries(TextToEntries):
else:
entry_filename = str(Path(raw_filename))
heading = parsed_entry.splitlines()[0] if re.search("^#+\s", parsed_entry) else ""
heading = parsed_entry.splitlines()[0] if re.search(r"^#+\s", parsed_entry) else ""
# Append base filename to compiled entry for context to model
# Increment heading level for heading entries and make filename as its top level heading
prefix = f"# {entry_filename}\n#" if heading else f"# {entry_filename}\n"

View file

@ -115,14 +115,20 @@ class OrgToEntries(TextToEntries):
return entries, entry_to_file_map
# Split this entry tree into sections by the next heading level in it
# Increment heading level until able to split entry into sections
# Increment heading level until able to split entry into sections or reach max heading level
# A successful split will result in at least 2 sections
max_heading_level = 100
next_heading_level = len(ancestry)
sections: List[str] = []
while len(sections) < 2:
while len(sections) < 2 and next_heading_level < max_heading_level:
next_heading_level += 1
sections = re.split(rf"(\n|^)(?=[*]{{{next_heading_level}}} .+\n?)", org_content, flags=re.MULTILINE)
# If unable to split entry into sections, log error and skip indexing it
if next_heading_level == max_heading_level:
logger.error(f"Unable to split current entry chunk: {org_content_with_ancestry[:20]}. Skip indexing it.")
return entries, entry_to_file_map
# Recurse down each non-empty section after parsing its body, heading and ancestry
for section in sections:
# Skip empty sections
@ -135,7 +141,7 @@ class OrgToEntries(TextToEntries):
# If first non-empty line is a heading with expected heading level
if re.search(rf"^\*{{{next_heading_level}}}\s", first_non_empty_line):
# Extract the section body without the heading
current_section_body = "\n".join(section.split(first_non_empty_line)[1:])
current_section_body = "\n".join(section.split(first_non_empty_line, 1)[1:])
# Parse the section heading into current section ancestry
current_section_title = first_non_empty_line[next_heading_level:].strip()
current_ancestry[next_heading_level] = current_section_title
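
The switch to `section.split(first_non_empty_line, 1)` matters when the heading text appears again later in the section body. Here is a small illustration with a made-up Org section; `section_body` is a hypothetical helper that only mirrors the split-and-join expression from the hunk above.

```python
def section_body(section: str, heading_line: str, maxsplit: int = -1) -> str:
    """Return the text after the heading line, using the same split/join as the hunk above."""
    return "\n".join(section.split(heading_line, maxsplit)[1:])


# Made-up section whose body repeats the heading text
section = "** Tasks\nintro\n** Tasks\nmore text"
heading = "** Tasks"

# Without maxsplit, every repeat of the heading is cut out of the body
body_all = section_body(section, heading)

# With maxsplit=1, only the leading heading is removed and the repeat is preserved
body_first = section_body(section, heading, 1)
```

With `maxsplit=1` the body keeps the second `** Tasks` line intact, so content that happens to repeat the heading is no longer silently dropped.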

View file

@ -124,7 +124,7 @@ class TextToEntries(ABC):
deletion_filenames: Set[str] = None,
user: KhojUser = None,
regenerate: bool = False,
file_to_text_map: dict[str, List[str]] = None,
file_to_text_map: dict[str, str] = None,
):
with timer("Constructed current entry hashes in", logger):
hashes_by_file = dict[str, set[str]]()
@ -137,7 +137,7 @@ class TextToEntries(ABC):
if regenerate:
with timer("Cleared existing dataset for regeneration in", logger):
logger.debug(f"Deleting all entries for file type {file_type}")
num_deleted_entries = EntryAdapters.delete_all_entries_by_type(user, file_type)
num_deleted_entries = EntryAdapters.delete_all_entries(user, file_type=file_type)
hashes_to_process = set()
with timer("Identified entries to add to database in", logger):
@ -192,16 +192,17 @@ class TextToEntries(ABC):
logger.debug(f"Added {len(added_entries)} {file_type} entries to database")
if file_to_text_map:
# get the list of file_names using added_entries
filenames_to_update = [entry.file_path for entry in added_entries]
# for each file_name in filenames_to_update, try getting the file object and updating raw_text and if it fails create a new file object
for file_name in filenames_to_update:
raw_text = " ".join(file_to_text_map[file_name])
file_object = FileObjectAdapters.get_file_objects_by_name(user, file_name)
if file_object:
FileObjectAdapters.update_raw_text(file_object, raw_text)
else:
FileObjectAdapters.create_file_object(user, file_name, raw_text)
with timer("Indexed text of modified file in", logger):
# get the set of modified files from added_entries
modified_files = {entry.file_path for entry in added_entries}
# create or update text of each updated file indexed on DB
for modified_file in modified_files:
raw_text = file_to_text_map[modified_file]
file_object = FileObjectAdapters.get_file_object_by_name(user, modified_file)
if file_object:
FileObjectAdapters.update_raw_text(file_object, raw_text)
else:
FileObjectAdapters.create_file_object(user, modified_file, raw_text)
new_dates = []
with timer("Indexed dates from added entries in", logger):

View file

@ -99,15 +99,13 @@ def anthropic_llm_thread(
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
]
max_prompt_size = max_prompt_size or DEFAULT_MAX_TOKENS_ANTHROPIC
with client.messages.stream(
messages=formatted_messages,
model=model_name, # type: ignore
temperature=temperature,
system=system_prompt,
timeout=20,
max_tokens=max_prompt_size,
max_tokens=DEFAULT_MAX_TOKENS_ANTHROPIC,
**(model_kwargs or dict()),
) as stream:
for text in stream.text_stream:

View file

@ -154,7 +154,7 @@ def converse(
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
if not is_none_or_empty(online_results):
conversation_primer = (
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
)

View file

@ -1,5 +1,4 @@
import logging
import os
from threading import Thread
from typing import Dict
@ -40,7 +39,7 @@ def completion_with_backoff(
client: openai.OpenAI = openai_clients.get(client_key)
if not client:
client = openai.OpenAI(
api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
api_key=openai_api_key,
base_url=api_base_url,
)
openai_clients[client_key] = client
@ -102,7 +101,7 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
client_key = f"{openai_api_key}--{api_base_url}"
if client_key not in openai_clients:
client: openai.OpenAI = openai.OpenAI(
api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
api_key=openai_api_key,
base_url=api_base_url,
)
openai_clients[client_key] = client

View file

@ -121,7 +121,7 @@ User's Notes:
## Image Generation
## --
image_generation_improve_prompt = PromptTemplate.from_template(
image_generation_improve_prompt_dalle = PromptTemplate.from_template(
"""
You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information originally from the query. You are provided with the following information to help you generate the prompt:
@ -143,6 +143,35 @@ Remember, now you are generating a prompt to improve the image generation. Add a
Improved Query:"""
)
image_generation_improve_prompt_sd = PromptTemplate.from_template(
"""
You are a talented creator. Write 2-5 sentences with precise image composition, position details to create an image.
Use the provided context below to add specific, fine details to the image composition.
Retain any important information and follow any instructions from the original prompt.
Put any text to be rendered in the image within double quotes in your improved prompt.
You are provided with the following context to help enhance the original prompt:
Today's Date: {current_date}
User's Location: {location}
User's Notes:
{references}
Online References:
{online_results}
Conversation Log:
{chat_history}
Original Prompt: "{query}"
Now create an improved prompt using the context provided above to generate an image.
Retain any important information and follow any instructions from the original prompt.
Use the additional context from the user's notes, online references and conversation log to improve the image generation.
Improved Prompt:"""
)
## Online Search Conversation
## --
online_search_conversation = PromptTemplate.from_template(

View file

@ -2,11 +2,11 @@ import asyncio
import json
import logging
import os
import urllib.parse
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple, Union
import aiohttp
import requests
from bs4 import BeautifulSoup
from markdownify import markdownify
@ -23,6 +23,10 @@ logger = logging.getLogger(__name__)
SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
SERPER_DEV_URL = "https://google.serper.dev/search"
JINA_READER_API_URL = "https://r.jina.ai/"
JINA_SEARCH_API_URL = "https://s.jina.ai/"
JINA_API_KEY = os.getenv("JINA_API_KEY")
OLOSTEP_API_KEY = os.getenv("OLOSTEP_API_KEY")
OLOSTEP_API_URL = "https://agent.olostep.com/olostep-p2p-incomingAPI"
OLOSTEP_QUERY_PARAMS = {
@ -50,9 +54,6 @@ async def search_online(
custom_filters: List[str] = [],
):
query += " ".join(custom_filters)
if not online_search_enabled():
logger.warn("SERPER_DEV_API_KEY is not set")
return {}
if not is_internet_connected():
logger.warning("Cannot search online as not connected to internet")
return {}
@ -61,27 +62,35 @@ async def search_online(
subqueries = await generate_online_subqueries(query, conversation_history, location)
response_dict = {}
for subquery in subqueries:
if subqueries:
logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
if send_status_func:
await send_status_func(f"**🌐 Searching the Internet for**: {subquery}")
logger.info(f"🌐 Searching the Internet for '{subquery}'")
response_dict[subquery] = search_with_google(subquery)
subqueries_str = "\n- " + "\n- ".join(list(subqueries))
await send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}")
# Gather distinct web pages from organic search results of each subquery without an instant answer
webpage_links = {
organic["link"]: subquery
with timer(f"Internet searches for {list(subqueries)} took", logger):
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
search_tasks = [search_func(subquery) for subquery in subqueries]
search_results = await asyncio.gather(*search_tasks)
response_dict = {subquery: search_result for subquery, search_result in search_results}
# Gather distinct web page data from organic results of each subquery without an instant answer.
# Content of web pages is directly available when Jina is used for search.
webpages = {
(organic.get("link"), subquery, organic.get("content"))
for subquery in response_dict
for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]
if "answerBox" not in response_dict[subquery]
}
# Read, extract relevant info from the retrieved web pages
if webpage_links:
if webpages:
webpage_links = [link for link, _, _ in webpages]
logger.info(f"🌐👀 Reading web pages at: {list(webpage_links)}")
if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
await send_status_func(f"**📖 Reading web pages**: {webpage_links_str}")
tasks = [read_webpage_and_extract_content(subquery, link) for link, subquery in webpage_links.items()]
tasks = [read_webpage_and_extract_content(subquery, link, content) for link, subquery, content in webpages]
results = await asyncio.gather(*tasks)
# Collect extracted info from the retrieved web pages
@ -92,23 +101,24 @@ async def search_online(
return response_dict
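
The reworked `search_online` above starts one search task per subquery, awaits them together with `asyncio.gather`, and re-keys the results by the query that each task returns. A minimal sketch of that pattern follows; `fake_search` and `search_all` are hypothetical names used only for illustration, and no network request is made.

```python
import asyncio
from typing import Dict, List, Tuple

SearchResult = Dict[str, List[Dict]]


async def fake_search(query: str) -> Tuple[str, SearchResult]:
    """Hypothetical stand-in for a search provider; makes no network call."""
    await asyncio.sleep(0)  # yield control, as a real HTTP request would
    return query, {"organic": [{"link": f"https://example.com/{query}"}]}


async def search_all(subqueries: List[str]) -> Dict[str, SearchResult]:
    # Start every search at once, then wait for all of them to finish.
    search_results = await asyncio.gather(*(fake_search(q) for q in subqueries))
    # Each task returns (query, result), so results can be keyed by query
    # regardless of the order in which the requests complete.
    return {query: result for query, result in search_results}
```

Returning the query alongside its result is what lets the caller build the dictionary without tracking task order separately; calling `asyncio.run(search_all(["cats", "dogs"]))` would give one entry per subquery.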
def search_with_google(subquery: str):
payload = json.dumps({"q": subquery})
async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
payload = json.dumps({"q": query})
headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
response = requests.request("POST", SERPER_DEV_URL, headers=headers, data=payload)
async with aiohttp.ClientSession() as session:
async with session.post(SERPER_DEV_URL, headers=headers, data=payload) as response:
if response.status != 200:
logger.error(await response.text())
return query, {}
json_response = await response.json()
extraction_fields = ["organic", "answerBox", "peopleAlsoAsk", "knowledgeGraph"]
extracted_search_result = {
field: json_response[field]
for field in extraction_fields
if not is_none_or_empty(json_response.get(field))
}
if response.status_code != 200:
logger.error(response.text)
return {}
json_response = response.json()
extraction_fields = ["organic", "answerBox", "peopleAlsoAsk", "knowledgeGraph"]
extracted_search_result = {
field: json_response[field] for field in extraction_fields if not is_none_or_empty(json_response.get(field))
}
return extracted_search_result
return query, extracted_search_result
async def read_webpages(
@ -134,10 +144,13 @@ async def read_webpages(
return response
async def read_webpage_and_extract_content(subquery: str, url: str) -> Tuple[str, Union[None, str], str]:
async def read_webpage_and_extract_content(
subquery: str, url: str, content: str = None
) -> Tuple[str, Union[None, str], str]:
try:
with timer(f"Reading web page at '{url}' took", logger):
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_at_url(url)
if is_none_or_empty(content):
with timer(f"Reading web page at '{url}' took", logger):
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subquery, content)
return subquery, extracted_info, url
@ -172,5 +185,41 @@ async def read_webpage_with_olostep(web_url: str) -> str:
return response_json["markdown_content"]
def online_search_enabled():
return SERPER_DEV_API_KEY is not None
async def read_webpage_with_jina(web_url: str) -> str:
jina_reader_api_url = f"{JINA_READER_API_URL}/{web_url}"
headers = {"Accept": "application/json", "X-Timeout": "30"}
if JINA_API_KEY:
headers["Authorization"] = f"Bearer {JINA_API_KEY}"
async with aiohttp.ClientSession() as session:
async with session.get(jina_reader_api_url, headers=headers) as response:
response.raise_for_status()
response_json = await response.json()
return response_json["data"]["content"]
async def search_with_jina(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
encoded_query = urllib.parse.quote(query)
jina_search_api_url = f"{JINA_SEARCH_API_URL}/{encoded_query}"
headers = {"Accept": "application/json"}
if JINA_API_KEY:
headers["Authorization"] = f"Bearer {JINA_API_KEY}"
async with aiohttp.ClientSession() as session:
async with session.get(jina_search_api_url, headers=headers) as response:
if response.status != 200:
logger.error(await response.text())
return query, {}
response_json = await response.json()
parsed_response = [
{
"title": item["title"],
"content": item.get("content"),
# rename description -> snippet for consistency
"snippet": item["description"],
# rename url -> link for consistency
"link": item["url"],
}
for item in response_json["data"]
]
return query, {"organic": parsed_response}
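
`search_with_jina` above renames fields so that its results match the `organic` shape already consumed for Serper results. The sketch below isolates that mapping; `normalize_jina_results` is a hypothetical helper name and the sample payload is made up for illustration, though the field names (`title`, `content`, `description`, `url`) are the ones used in the diff.

```python
from typing import Dict, List, Optional


def normalize_jina_results(items: List[Dict[str, Optional[str]]]) -> Dict[str, List[Dict]]:
    """Map Jina-style result items onto the "organic" shape used for Serper results."""
    organic = [
        {
            "title": item["title"],
            "content": item.get("content"),  # page text, when the search API returns it
            "snippet": item["description"],  # description -> snippet
            "link": item["url"],  # url -> link
        }
        for item in items
    ]
    return {"organic": organic}


# Made-up sample payload, for illustration only.
sample = [{"title": "Example", "description": "An example page", "url": "https://example.com"}]
normalized = normalize_jina_results(sample)
```

Because `content` is read with `.get`, an item without page text yields `None`, which is the case where `read_webpage_and_extract_content` falls back to fetching the page itself.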

View file

@ -13,6 +13,7 @@ from starlette.authentication import requires
from starlette.websockets import WebSocketDisconnect
from websockets import ConnectionClosedOK
from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import (
ConversationAdapters,
EntryAdapters,
@ -28,11 +29,7 @@ from khoj.processor.conversation.prompts import (
)
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import (
online_search_enabled,
read_webpages,
search_online,
)
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import (
ApiUserRateLimiter,
@ -153,7 +150,17 @@ async def sendfeedback(request: Request, data: FeedbackData):
@api_chat.post("/speech")
@requires(["authenticated", "premium"])
async def text_to_speech(request: Request, common: CommonQueryParams, text: str):
async def text_to_speech(
request: Request,
common: CommonQueryParams,
text: str,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=20, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=300, window=60 * 60 * 24, slug="chat_day")
),
) -> Response:
voice_model = await ConversationAdapters.aget_voice_model_config(request.user.object)
params = {"text_to_speak": text}
@ -350,17 +357,19 @@ def duplicate_chat_history_public_conversation(
conversation_id: int,
):
user = request.user.object
domain = request.headers.get("host")
scheme = request.url.scheme
# Throw unauthorized exception if domain not in ALLOWED_HOSTS
host_domain = domain.split(":")[0]
if host_domain not in ALLOWED_HOSTS:
raise HTTPException(status_code=401, detail="Unauthorized domain")
# Duplicate Conversation History to Public Conversation
conversation = ConversationAdapters.get_conversation_by_user(user, request.user.client_app, conversation_id)
public_conversation = ConversationAdapters.make_public_conversation_copy(conversation)
public_conversation_url = PublicConversationAdapters.get_public_conversation_url(public_conversation)
domain = request.headers.get("host")
scheme = request.url.scheme
update_telemetry_state(
request=request,
telemetry_type="api",
@ -610,6 +619,7 @@ async def websocket_endpoint(
meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(q, meta_log, is_automated_task)
@ -625,8 +635,18 @@ async def websocket_endpoint(
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
if ConversationCommand.Summarize in conversation_commands:
file_filters = conversation.file_filters
file_filters = conversation.file_filters if conversation else []
# Skip trying to summarize if
if (
# summarization intent was inferred
ConversationCommand.Summarize in conversation_commands
# and not triggered via slash command
and not used_slash_summarize
# but we can't actually summarize
and len(file_filters) != 1
):
conversation_commands.remove(ConversationCommand.Summarize)
elif ConversationCommand.Summarize in conversation_commands:
response_log = ""
if len(file_filters) == 0:
response_log = "No files selected for summarization. Please add files using the section on the left."
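
The guard added in this hunk drops an inferred summarize command when summarization cannot actually run, while leaving an explicit `/summarize` in place so the user still gets an error message. A minimal sketch of that decision, assuming a hypothetical `Command` enum in place of `ConversationCommand` and a hypothetical helper name:

```python
from enum import Enum
from typing import List


class Command(Enum):
    """Hypothetical stand-in for ConversationCommand."""

    Default = "default"
    Summarize = "summarize"


def drop_unusable_summarize(
    commands: List[Command], file_filters: List[str], used_slash_summarize: bool
) -> List[Command]:
    commands = list(commands)  # work on a copy; the caller's list is untouched
    if (
        Command.Summarize in commands  # summarization intent is present
        and not used_slash_summarize  # but it was inferred, not requested
        and len(file_filters) != 1  # and exactly one file is not selected
    ):
        commands.remove(Command.Summarize)
    return commands
```

With no file selected, an inferred `Summarize` is removed, whereas the same command issued via the slash command is kept so the later branch can report why it failed.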
@ -741,22 +761,16 @@ async def websocket_endpoint(
conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands:
if not online_search_enabled():
conversation_commands.remove(ConversationCommand.Online)
# If online search is not enabled, try to read webpages directly
if ConversationCommand.Webpage not in conversation_commands:
conversation_commands.append(ConversationCommand.Webpage)
else:
try:
online_results = await search_online(
defiltered_query, meta_log, location, send_status_update, custom_filters
)
except ValueError as e:
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
await send_complete_llm_response(
f"Error searching online: {e}. Attempting to respond without online results"
)
continue
try:
online_results = await search_online(
defiltered_query, meta_log, location, send_status_update, custom_filters
)
except ValueError as e:
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
await send_complete_llm_response(
f"Error searching online: {e}. Attempting to respond without online results"
)
continue
if ConversationCommand.Webpage in conversation_commands:
try:
@ -1041,18 +1055,10 @@ async def chat(
conversation_commands.remove(ConversationCommand.Notes)
if ConversationCommand.Online in conversation_commands:
if not online_search_enabled():
conversation_commands.remove(ConversationCommand.Online)
# If online search is not enabled, try to read webpages directly
if ConversationCommand.Webpage not in conversation_commands:
conversation_commands.append(ConversationCommand.Webpage)
else:
try:
online_results = await search_online(
defiltered_query, meta_log, location, custom_filters=_custom_filters
)
except ValueError as e:
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
try:
online_results = await search_online(defiltered_query, meta_log, location, custom_filters=_custom_filters)
except ValueError as e:
logger.warning(f"Error searching online: {e}. Attempting to respond without online results")
if ConversationCommand.Webpage in conversation_commands:
try:

View file

@ -183,7 +183,7 @@ async def remove_content_source_data(
raise ValueError(f"Invalid content source: {content_source}")
elif content_object != "Computer":
await content_object.objects.filter(user=user).adelete()
await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source)
await sync_to_async(EntryAdapters.delete_all_entries)(user, file_source=content_source)
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
return {"status": "ok"}
@ -341,6 +341,35 @@ async def update_search_model(
return {"status": "ok"}
@api_config.post("/data/paint/model", status_code=200)
@requires(["authenticated"])
async def update_paint_model(
request: Request,
id: str,
client: Optional[str] = None,
):
user = request.user.object
subscribed = has_required_scope(request, ["premium"])
if not subscribed:
raise HTTPException(status_code=403, detail="User is not subscribed to premium")
new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))
if new_config is None:
return {"status": "error", "message": "Model not found"}
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_paint_model",
client=client,
metadata={"paint_model": new_config.setting.model_name},
)
return {"status": "ok"}
@api_config.get("/index/size", response_model=Dict[str, int])
@requires(["authenticated"])
async def get_indexed_data_size(request: Request, common: CommonQueryParams):

View file

@ -42,8 +42,12 @@ if not state.anonymous_mode:
from google.oauth2 import id_token
except ImportError:
missing_requirements += ["Install the Khoj production package with `pip install khoj-assistant[prod]`"]
if not os.environ.get("GOOGLE_CLIENT_ID") or not os.environ.get("GOOGLE_CLIENT_SECRET"):
missing_requirements += ["Set your GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET as environment variables"]
if not os.environ.get("RESEND_API_KEY") and (
not os.environ.get("GOOGLE_CLIENT_ID") or not os.environ.get("GOOGLE_CLIENT_SECRET")
):
missing_requirements += [
"Set your RESEND_API_KEY or GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET as environment variables"
]
if missing_requirements:
requirements_string = "\n - " + "\n - ".join(missing_requirements)
error_msg = f"🚨 Start Khoj with --anonymous-mode flag or to enable authentication:{requirements_string}"
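
The updated check accepts either a `RESEND_API_KEY` or the pair of Google OAuth variables. The sketch below restates that condition as a pure function over a mapping so it can be exercised without touching the real environment; `missing_auth_requirements` is a hypothetical name, not a function from the codebase.

```python
from typing import List, Mapping


def missing_auth_requirements(env: Mapping[str, str]) -> List[str]:
    """Return the unmet authentication requirements for the given environment."""
    has_resend = bool(env.get("RESEND_API_KEY"))
    # Google sign-in needs both the client id and the client secret.
    has_google = bool(env.get("GOOGLE_CLIENT_ID")) and bool(env.get("GOOGLE_CLIENT_SECRET"))
    if has_resend or has_google:
        return []
    return ["Set your RESEND_API_KEY or GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET as environment variables"]
```

Passing `os.environ` as `env` would reproduce the check in the diff; a lone `GOOGLE_CLIENT_ID` without its secret still counts as missing.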

View file

@ -453,12 +453,14 @@ async def generate_better_image_prompt(
location_data: LocationData,
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
model_type: Optional[str] = None,
) -> str:
"""
Generate a better image prompt from the given query
"""
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
@ -477,21 +479,34 @@ async def generate_better_image_prompt(
elif online_results[result].get("webpages"):
simplified_online_results[result] = online_results[result]["webpages"]
image_prompt = prompts.image_generation_improve_prompt.format(
query=q,
chat_history=conversation_history,
location=location_prompt,
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
)
if model_type == TextToImageModelConfig.ModelType.OPENAI:
image_prompt = prompts.image_generation_improve_prompt_dalle.format(
query=q,
chat_history=conversation_history,
location=location_prompt,
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
)
elif model_type == TextToImageModelConfig.ModelType.STABILITYAI:
image_prompt = prompts.image_generation_improve_prompt_sd.format(
query=q,
chat_history=conversation_history,
location=location_prompt,
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
)
summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(image_prompt, chat_model_option=summarizer_model)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
return response.strip()
return response
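
The tail of `generate_better_image_prompt` now trims whitespace first and then removes one pair of surrounding quotes from the model's reply. A small sketch of that cleanup in isolation, using a hypothetical helper name:

```python
def clean_model_response(response: str) -> str:
    """Trim whitespace, then drop one leading and one trailing quote character."""
    response = response.strip()
    # str.startswith / str.endswith accept a tuple of alternatives.
    if response.startswith(('"', "'")) and response.endswith(('"', "'")):
        response = response[1:-1]
    return response
```

As written in the diff, the check does not require the opening and closing quote characters to match, so a reply that opens with `'` and closes with `"` is also unwrapped.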
async def send_message_to_model_wrapper(
@ -747,74 +762,110 @@ async def text_to_image(
image_url = None
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
message = "Failed to generate image. Setup image generation on the server."
return image_url or image, status_code, message, intent_type.value
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
logger.info("Generating image with OpenAI")
text2image_model = text_to_image_config.model_name
chat_history = ""
for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
chat_history += f"Q: Query: {chat['intent']['query']}\n"
chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n"
try:
with timer("Improve the original user query", logger):
if send_status_func:
await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
improved_image_prompt = await generate_better_image_prompt(
message,
chat_history,
location_data=location_data,
note_references=references,
online_results=online_results,
)
with timer("Generate image with OpenAI", logger):
if send_status_func:
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
text2image_model = text_to_image_config.model_name
chat_history = ""
for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
chat_history += f"Q: Query: {chat['intent']['query']}\n"
chat_history += f"A: Improved Query: {chat['intent']['inferred-queries'][0]}\n"
with timer("Improve the original user query", logger):
if send_status_func:
await send_status_func("**✍🏽 Enhancing the Painting Prompt**")
improved_image_prompt = await generate_better_image_prompt(
message,
chat_history,
location_data=location_data,
note_references=references,
online_results=online_results,
model_type=text_to_image_config.model_type,
)
if send_status_func:
await send_status_func(f"**🖼️ Painting using Enhanced Prompt**:\n{improved_image_prompt}")
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
with timer("Generate image with OpenAI", logger):
api_key = None
if text_to_image_config.api_key:
api_key = text_to_image_config.api_key
elif text_to_image_config.openai_config:
api_key = text_to_image_config.openai_config.api_key
elif state.openai_client:
api_key = state.openai_client.api_key
auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
try:
response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
prompt=improved_image_prompt,
model=text2image_model,
response_format="b64_json",
extra_headers=auth_header,
)
image = response.data[0].b64_json
with timer("Convert image to webp", logger):
# Convert png to webp for faster loading
decoded_image = base64.b64decode(image)
image_io = io.BytesIO(decoded_image)
png_image = Image.open(image_io)
webp_image_io = io.BytesIO()
png_image.save(webp_image_io, "WEBP")
webp_image_bytes = webp_image_io.getvalue()
webp_image_io.close()
image_io.close()
except (openai.OpenAIError, openai.BadRequestError, openai.APIConnectionError) as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
return image_url or image, status_code, message, intent_type.value
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value
with timer("Upload image to S3", logger):
image_url = upload_image(webp_image_bytes, user.uuid)
if image_url:
intent_type = ImageIntentType.TEXT_TO_IMAGE2
else:
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
image = base64.b64encode(webp_image_bytes).decode("utf-8")
return image_url or image, status_code, improved_image_prompt, intent_type.value
except (openai.OpenAIError, openai.BadRequestError, openai.APIConnectionError) as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
return image_url or image, status_code, message, intent_type.value
else:
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
with timer("Generate image with Stability AI", logger):
try:
response = requests.post(
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
files={"none": ""},
data={
"prompt": improved_image_prompt,
"model": text2image_model,
"mode": "text-to-image",
"output_format": "png",
"seed": 1032622926,
"aspect_ratio": "1:1",
},
)
decoded_image = response.content
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
message = f"Image generation failed with Stability AI error: {e}"
status_code = e.status_code # type: ignore
return image_url or image, status_code, message, intent_type.value
return image_url or image, status_code, response, intent_type.value
with timer("Convert image to webp", logger):
# Convert png to webp for faster loading
image_io = io.BytesIO(decoded_image)
png_image = Image.open(image_io)
webp_image_io = io.BytesIO()
png_image.save(webp_image_io, "WEBP")
webp_image_bytes = webp_image_io.getvalue()
webp_image_io.close()
image_io.close()
with timer("Upload image to S3", logger):
image_url = upload_image(webp_image_bytes, user.uuid)
if image_url:
intent_type = ImageIntentType.TEXT_TO_IMAGE2
else:
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
image = base64.b64encode(webp_image_bytes).decode("utf-8")
return image_url or image, status_code, improved_image_prompt, intent_type.value
class ApiUserRateLimiter:

View file

@ -9,6 +9,7 @@ from starlette.authentication import requires
from khoj.database.models import GithubConfig, KhojUser, NotionConfig
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
from khoj.processor.content.github.github_to_entries import GithubToEntries
from khoj.processor.content.images.image_to_entries import ImageToEntries
from khoj.processor.content.markdown.markdown_to_entries import MarkdownToEntries
from khoj.processor.content.notion.notion_to_entries import NotionToEntries
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
@ -41,6 +42,7 @@ class IndexerInput(BaseModel):
markdown: Optional[dict[str, str]] = None
pdf: Optional[dict[str, bytes]] = None
plaintext: Optional[dict[str, str]] = None
image: Optional[dict[str, bytes]] = None
docx: Optional[dict[str, bytes]] = None
@ -65,7 +67,14 @@ async def update(
),
):
user = request.user.object
index_files: Dict[str, Dict[str, str]] = {"org": {}, "markdown": {}, "pdf": {}, "plaintext": {}, "docx": {}}
index_files: Dict[str, Dict[str, str]] = {
"org": {},
"markdown": {},
"pdf": {},
"plaintext": {},
"image": {},
"docx": {},
}
try:
logger.info(f"📬 Updating content index via API call by {client} client")
for file in files:
@ -81,6 +90,7 @@ async def update(
markdown=index_files["markdown"],
pdf=index_files["pdf"],
plaintext=index_files["plaintext"],
image=index_files["image"],
docx=index_files["docx"],
)
@ -133,6 +143,7 @@ async def update(
"num_markdown": len(index_files["markdown"]),
"num_pdf": len(index_files["pdf"]),
"num_plaintext": len(index_files["plaintext"]),
"num_image": len(index_files["image"]),
"num_docx": len(index_files["docx"]),
}
@ -300,6 +311,23 @@ def configure_content(
logger.error(f"🚨 Failed to setup Notion: {e}", exc_info=True)
success = False
try:
# Initialize Image Search
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value) and files[
"image"
]:
logger.info("🖼️ Setting up search for images")
# Extract Entries, Generate Image Embeddings
text_search.setup(
ImageToEntries,
files.get("image"),
regenerate=regenerate,
full_corpus=full_corpus,
user=user,
)
except Exception as e:
logger.error(f"🚨 Failed to setup images: {e}", exc_info=True)
success = False
try:
if (search_type == state.SearchType.All.value or search_type == state.SearchType.Docx.value) and files["docx"]:
logger.info("📄 Setting up search for docx")

View file

@ -261,6 +261,12 @@ def config_page(request: Request):
current_search_model_option = adapters.get_user_search_model_or_default(user)
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
all_paint_model_options = list()
for paint_model in paint_model_options:
all_paint_model_options.append({"model_name": paint_model.model_name, "id": paint_model.id})
notion_oauth_url = get_notion_auth_url(user)
eleven_labs_enabled = is_eleven_labs_enabled()
@ -283,10 +289,12 @@ def config_page(request: Request):
"anonymous_mode": state.anonymous_mode,
"username": user.username,
"given_name": given_name,
"conversation_options": all_conversation_options,
"search_model_options": all_search_model_options,
"selected_search_model_config": current_search_model_option.id,
"conversation_options": all_conversation_options,
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
"paint_model_options": all_paint_model_options,
"selected_paint_model_config": selected_paint_model_config.id if selected_paint_model_config else None,
"user_photo": user_picture,
"billing_enabled": state.billing_enabled,
"subscription_state": user_subscription_state,

View file

@ -118,9 +118,9 @@ def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]:
elif file_type in ["application/msword", "application/vnd.openxmlformats-officedocument.wordprocessingml.document"]:
return "docx", encoding
elif file_type in ["image/jpeg"]:
return "jpeg", encoding
return "image", encoding
elif file_type in ["image/png"]:
return "png", encoding
return "image", encoding
elif content_group in ["code", "text"]:
return "plaintext", encoding
else:

View file

@ -70,6 +70,7 @@ class ContentConfig(ConfigBase):
plaintext: Optional[TextContentConfig] = None
github: Optional[GithubContentConfig] = None
notion: Optional[NotionContentConfig] = None
image: Optional[TextContentConfig] = None
docx: Optional[TextContentConfig] = None

BIN tests/data/images/nasdaq.jpg (new vendored binary file, 1.4 MiB, not shown)

BIN tests/data/images/testocr.png (new vendored binary file, 59 KiB, not shown)

View file

@ -0,0 +1,21 @@
import os
from khoj.processor.content.images.image_to_entries import ImageToEntries
def test_png_to_jsonl():
with open("tests/data/images/testocr.png", "rb") as f:
image_bytes = f.read()
data = {"tests/data/images/testocr.png": image_bytes}
entries = ImageToEntries.extract_image_entries(image_files=data)
assert len(entries) == 2
assert "opencv-python" in entries[1][0].raw
def test_jpg_to_jsonl():
with open("tests/data/images/nasdaq.jpg", "rb") as f:
image_bytes = f.read()
data = {"tests/data/images/nasdaq.jpg": image_bytes}
entries = ImageToEntries.extract_image_entries(image_files=data)
assert len(entries) == 2
assert "investments" in entries[1][0].raw

View file

@ -62,7 +62,6 @@ def test_offline_chat_with_no_chat_history_or_retrieved_content(client_offline_c
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("SERPER_DEV_API_KEY") is None, reason="requires SERPER_DEV_API_KEY")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_online_content(client_offline_chat):
@ -75,18 +74,18 @@ def test_chat_with_online_content(client_offline_chat):
response_message = response_message.split("### compiled references")[0]
# Assert
expected_responses = ["http://www.paulgraham.com/greatwork.html"]
expected_responses = [
"https://paulgraham.com/greatwork.html",
"https://www.paulgraham.com/greatwork.html",
"http://www.paulgraham.com/greatwork.html",
]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected links or serper not setup in response but got: " + response_message
)
assert any(
[expected_response in response_message for expected_response in expected_responses]
), f"Expected links: {expected_responses}. Actual response: {response_message}"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(
os.getenv("SERPER_DEV_API_KEY") is None or os.getenv("OLOSTEP_API_KEY") is None,
reason="requires SERPER_DEV_API_KEY and OLOSTEP_API_KEY",
)
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_online_webpage_content(client_offline_chat):
@ -101,9 +100,9 @@ def test_chat_with_online_webpage_content(client_offline_chat):
# Assert
expected_responses = ["185", "1871", "horse"]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected links or serper not setup in response but got: " + response_message
)
assert any(
[expected_response in response_message for expected_response in expected_responses]
), f"Expected response with {expected_responses}. But actual response had: {response_message}"
# ----------------------------------------------------------------------------------------------------
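Review note: the updated assertions accept any one of several expected substrings and report both the expectation and the actual response on failure. The pattern can be sketched on its own as below; the helper name `assert_any_expected` is hypothetical and exists only for illustration, since the tests inline the assertion.

```python
def assert_any_expected(response_message: str, expected_responses: list[str]) -> None:
    """Pass if at least one expected substring appears in the response."""
    assert any(
        expected in response_message for expected in expected_responses
    ), f"Expected one of: {expected_responses}. Actual response: {response_message}"


# URL variants copied from the diff: the scheme and the "www." prefix may differ.
URL_VARIANTS = [
    "https://paulgraham.com/greatwork.html",
    "https://www.paulgraham.com/greatwork.html",
    "http://www.paulgraham.com/greatwork.html",
]

assert_any_expected("Read https://paulgraham.com/greatwork.html for more.", URL_VARIANTS)
```

Listing the variants explicitly keeps the test tolerant of how the search provider formats the link, while the f-string message makes a failure readable without rerunning the test.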

View file

@ -61,7 +61,6 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(os.getenv("SERPER_DEV_API_KEY") is None, reason="requires SERPER_DEV_API_KEY")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_online_content(chat_client):
@ -74,18 +73,18 @@ def test_chat_with_online_content(chat_client):
response_message = response_message.split("### compiled references")[0]
# Assert
expected_responses = ["http://www.paulgraham.com/greatwork.html"]
expected_responses = [
"https://paulgraham.com/greatwork.html",
"https://www.paulgraham.com/greatwork.html",
"http://www.paulgraham.com/greatwork.html",
]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected links or serper not setup in response but got: " + response_message
)
assert any(
[expected_response in response_message for expected_response in expected_responses]
), f"Expected links: {expected_responses}. Actual response: {response_message}"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.skipif(
os.getenv("SERPER_DEV_API_KEY") is None or os.getenv("OLOSTEP_API_KEY") is None,
reason="requires SERPER_DEV_API_KEY and OLOSTEP_API_KEY",
)
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_online_webpage_content(chat_client):
@ -100,9 +99,9 @@ def test_chat_with_online_webpage_content(chat_client):
# Assert
expected_responses = ["185", "1871", "horse"]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected links or serper not setup in response but got: " + response_message
)
assert any(
[expected_response in response_message for expected_response in expected_responses]
), f"Expected links: {expected_responses}. Actual response: {response_message}"
# ----------------------------------------------------------------------------------------------------

View file

@ -1,5 +1,6 @@
import os
import re
import time
from khoj.processor.content.org_mode.org_to_entries import OrgToEntries
from khoj.processor.content.text_to_entries import TextToEntries
@ -41,6 +42,35 @@ def test_configure_indexing_heading_only_entries(tmp_path):
assert is_none_or_empty(entries[1])
def test_extract_entries_when_child_headings_have_same_prefix():
"""Extract org entries from entries having child headings with same prefix.
Prevents regressions like the one fixed in PR #840.
"""
# Arrange
tmp_path = "tests/data/org/same_prefix_headings.org"
entry: str = """
** 1
*** 1.1
**** 1.1.2
""".strip()
data = {
f"{tmp_path}": entry,
}
# Act
# Extract Entries from specified Org files
start = time.time()
entries = OrgToEntries.extract_org_entries(org_files=data, max_tokens=2)
end = time.time()
indexing_time = end - start
# Assert
explanation_msg = (
"It should not take more than 6 seconds to index. Entry extraction may have gone into an infinite loop."
)
assert indexing_time < 6 * len(entries), explanation_msg
def test_entry_split_when_exceeds_max_tokens():
"Ensure entries with compiled words exceeding max_tokens are split."
# Arrange
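Review note: `test_extract_entries_when_child_headings_have_same_prefix` in the hunk above guards against runaway indexing by timing the extraction call with `time.time()` and bounding the elapsed time by `6 * len(entries)`. A minimal sketch of the same timing guard, assuming `time.perf_counter()` as the clock (a monotonic timer suited to measuring intervals); the helper `timed` is hypothetical and is not part of the test suite.

```python
import time


def timed(func, *args, **kwargs):
    """Call func and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = func(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


# Stand-in workload; in the test this would be the entry extraction call.
entries, elapsed = timed(sorted, ["** 1", "*** 1.1", "**** 1.1.2"])
assert elapsed < 6 * len(entries), "Entry extraction took unexpectedly long."
```

Note that a timing bound only detects a slow or looping call after it returns; a call that never terminates would still need a test-level timeout to fail.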

View file

@ -52,5 +52,6 @@
"1.12.1": "0.15.0",
"1.13.0": "0.15.0",
"1.14.0": "0.15.0",
"1.15.0": "0.15.0"
"1.15.0": "0.15.0",
"1.16.0": "0.15.0"
}