mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Resolve merge conflicts in auth.py with remove KhojApiUser import
This commit is contained in:
commit
6b17aeb32d
35 changed files with 605 additions and 266 deletions
|
@ -10,7 +10,15 @@ services:
|
||||||
POSTGRES_DB: postgres
|
POSTGRES_DB: postgres
|
||||||
volumes:
|
volumes:
|
||||||
- khoj_db:/var/lib/postgresql/data/
|
- khoj_db:/var/lib/postgresql/data/
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 5
|
||||||
server:
|
server:
|
||||||
|
depends_on:
|
||||||
|
database:
|
||||||
|
condition: service_healthy
|
||||||
# Use the following line to use the latest version of khoj. Otherwise, it will build from source.
|
# Use the following line to use the latest version of khoj. Otherwise, it will build from source.
|
||||||
image: ghcr.io/khoj-ai/khoj:latest
|
image: ghcr.io/khoj-ai/khoj:latest
|
||||||
# Uncomment the following line to build from source. This will take a few minutes. Comment the next two lines out if you want to use the official image.
|
# Uncomment the following line to build from source. This will take a few minutes. Comment the next two lines out if you want to use the official image.
|
||||||
|
@ -24,20 +32,6 @@ services:
|
||||||
- "42110:42110"
|
- "42110:42110"
|
||||||
working_dir: /app
|
working_dir: /app
|
||||||
volumes:
|
volumes:
|
||||||
- .:/app
|
|
||||||
# These mounted volumes hold the raw data that should be indexed for search.
|
|
||||||
# The path in your local directory (left hand side)
|
|
||||||
# points to the files you want to index.
|
|
||||||
# The path of the mounted directory (right hand side),
|
|
||||||
# must match the path prefix in your config file.
|
|
||||||
- ./tests/data/org/:/data/org/
|
|
||||||
- ./tests/data/images/:/data/images/
|
|
||||||
- ./tests/data/markdown/:/data/markdown/
|
|
||||||
- ./tests/data/pdf/:/data/pdf/
|
|
||||||
# Embeddings and models are populated after the first run
|
|
||||||
# You can set these volumes to point to empty directories on host
|
|
||||||
- ./tests/data/embeddings/:/root/.khoj/content/
|
|
||||||
- ./tests/data/models/:/root/.khoj/search/
|
|
||||||
- khoj_config:/root/.khoj/
|
- khoj_config:/root/.khoj/
|
||||||
- khoj_models:/root/.cache/torch/sentence_transformers
|
- khoj_models:/root/.cache/torch/sentence_transformers
|
||||||
# Use 0.0.0.0 to explicitly set the host ip for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/
|
# Use 0.0.0.0 to explicitly set the host ip for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/
|
||||||
|
@ -47,9 +41,11 @@ services:
|
||||||
- POSTGRES_PASSWORD=postgres
|
- POSTGRES_PASSWORD=postgres
|
||||||
- POSTGRES_HOST=database
|
- POSTGRES_HOST=database
|
||||||
- POSTGRES_PORT=5432
|
- POSTGRES_PORT=5432
|
||||||
- GOOGLE_CLIENT_SECRET=bar
|
- KHOJ_DJANGO_SECRET_KEY=secret
|
||||||
- GOOGLE_CLIENT_ID=foo
|
- KHOJ_DEBUG=True
|
||||||
command: --host="0.0.0.0" --port=42110 -vv
|
- KHOJ_ADMIN_EMAIL=username@example.com
|
||||||
|
- KHOJ_ADMIN_PASSWORD=password
|
||||||
|
command: --host="0.0.0.0" --port=42110 -vv --anonymous-mode
|
||||||
|
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import math
|
||||||
from typing import Optional, Type, TypeVar, List
|
from typing import Optional, Type, TypeVar, List
|
||||||
from datetime import date, datetime, timedelta
|
from datetime import date, datetime, timedelta
|
||||||
import secrets
|
import secrets
|
||||||
|
@ -99,6 +100,8 @@ async def create_google_user(token: dict) -> KhojUser:
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await Subscription.objects.acreate(user=user, type="trial")
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@ -433,12 +436,19 @@ class EntryAdapters:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def search_with_embeddings(
|
def search_with_embeddings(
|
||||||
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
|
user: KhojUser,
|
||||||
|
embeddings: Tensor,
|
||||||
|
max_results: int = 10,
|
||||||
|
file_type_filter: str = None,
|
||||||
|
raw_query: str = None,
|
||||||
|
max_distance: float = math.inf,
|
||||||
):
|
):
|
||||||
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
|
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
|
||||||
relevant_entries = relevant_entries.filter(user=user).annotate(
|
relevant_entries = relevant_entries.filter(user=user).annotate(
|
||||||
distance=CosineDistance("embeddings", embeddings)
|
distance=CosineDistance("embeddings", embeddings)
|
||||||
)
|
)
|
||||||
|
relevant_entries = relevant_entries.filter(distance__lte=max_distance)
|
||||||
|
|
||||||
if file_type_filter:
|
if file_type_filter:
|
||||||
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
|
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
|
||||||
relevant_entries = relevant_entries.order_by("distance")
|
relevant_entries = relevant_entries.order_by("distance")
|
||||||
|
|
|
@ -8,6 +8,7 @@ from database.models import (
|
||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
OfflineChatProcessorConversationConfig,
|
OfflineChatProcessorConversationConfig,
|
||||||
|
Subscription,
|
||||||
)
|
)
|
||||||
|
|
||||||
admin.site.register(KhojUser, UserAdmin)
|
admin.site.register(KhojUser, UserAdmin)
|
||||||
|
@ -15,3 +16,4 @@ admin.site.register(KhojUser, UserAdmin)
|
||||||
admin.site.register(ChatModelOptions)
|
admin.site.register(ChatModelOptions)
|
||||||
admin.site.register(OpenAIProcessorConversationConfig)
|
admin.site.register(OpenAIProcessorConversationConfig)
|
||||||
admin.site.register(OfflineChatProcessorConversationConfig)
|
admin.site.register(OfflineChatProcessorConversationConfig)
|
||||||
|
admin.site.register(Subscription)
|
||||||
|
|
21
src/database/migrations/0015_alter_subscription_user.py
Normal file
21
src/database/migrations/0015_alter_subscription_user.py
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
# Generated by Django 4.2.5 on 2023-11-11 05:39
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.db.models.deletion
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0014_alter_googleuser_picture"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="subscription",
|
||||||
|
name="user",
|
||||||
|
field=models.OneToOneField(
|
||||||
|
on_delete=django.db.models.deletion.CASCADE, related_name="subscription", to=settings.AUTH_USER_MODEL
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
|
@ -0,0 +1,17 @@
|
||||||
|
# Generated by Django 4.2.5 on 2023-11-11 06:15
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0015_alter_subscription_user"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="subscription",
|
||||||
|
name="renewal_date",
|
||||||
|
field=models.DateTimeField(blank=True, default=None, null=True),
|
||||||
|
),
|
||||||
|
]
|
|
@ -51,10 +51,10 @@ class Subscription(BaseModel):
|
||||||
TRIAL = "trial"
|
TRIAL = "trial"
|
||||||
STANDARD = "standard"
|
STANDARD = "standard"
|
||||||
|
|
||||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE, related_name="subscription")
|
||||||
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
|
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
|
||||||
is_recurring = models.BooleanField(default=False)
|
is_recurring = models.BooleanField(default=False)
|
||||||
renewal_date = models.DateTimeField(null=True, default=None)
|
renewal_date = models.DateTimeField(null=True, default=None, blank=True)
|
||||||
|
|
||||||
|
|
||||||
class NotionConfig(BaseModel):
|
class NotionConfig(BaseModel):
|
||||||
|
|
|
@ -577,12 +577,12 @@
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
transition: background 0.2s ease-in-out;
|
transition: background 0.2s ease-in-out;
|
||||||
text-align: left;
|
text-align: left;
|
||||||
max-height: 50px;
|
max-height: 75px;
|
||||||
transition: max-height 0.3s ease-in-out;
|
transition: max-height 0.3s ease-in-out;
|
||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
}
|
}
|
||||||
button.reference-button.expanded {
|
button.reference-button.expanded {
|
||||||
max-height: 200px;
|
max-height: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
button.reference-button::before {
|
button.reference-button::before {
|
||||||
|
|
|
@ -198,12 +198,6 @@ khojKeyInput.addEventListener('blur', async () => {
|
||||||
khojKeyInput.value = token;
|
khojKeyInput.value = token;
|
||||||
});
|
});
|
||||||
|
|
||||||
const syncButton = document.getElementById('sync-data');
|
|
||||||
syncButton.addEventListener('click', async () => {
|
|
||||||
loadingBar.style.display = 'block';
|
|
||||||
await window.syncDataAPI.syncData(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
const syncForceButton = document.getElementById('sync-force');
|
const syncForceButton = document.getElementById('sync-force');
|
||||||
syncForceButton.addEventListener('click', async () => {
|
syncForceButton.addEventListener('click', async () => {
|
||||||
loadingBar.style.display = 'block';
|
loadingBar.style.display = 'block';
|
||||||
|
|
|
@ -188,7 +188,6 @@
|
||||||
fetch(url, { headers })
|
fetch(url, { headers })
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
console.log(data);
|
|
||||||
document.getElementById("results").innerHTML = render_results(data, query, type);
|
document.getElementById("results").innerHTML = render_results(data, query, type);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import { Notice, Plugin } from 'obsidian';
|
import { Notice, Plugin, request } from 'obsidian';
|
||||||
import { KhojSetting, KhojSettingTab, DEFAULT_SETTINGS } from 'src/settings'
|
import { KhojSetting, KhojSettingTab, DEFAULT_SETTINGS } from 'src/settings'
|
||||||
import { KhojSearchModal } from 'src/search_modal'
|
import { KhojSearchModal } from 'src/search_modal'
|
||||||
import { KhojChatModal } from 'src/chat_modal'
|
import { KhojChatModal } from 'src/chat_modal'
|
||||||
|
@ -69,6 +69,25 @@ export default class Khoj extends Plugin {
|
||||||
async loadSettings() {
|
async loadSettings() {
|
||||||
// Load khoj obsidian plugin settings
|
// Load khoj obsidian plugin settings
|
||||||
this.settings = Object.assign({}, DEFAULT_SETTINGS, await this.loadData());
|
this.settings = Object.assign({}, DEFAULT_SETTINGS, await this.loadData());
|
||||||
|
|
||||||
|
// Check if khoj backend is configured, note if cannot connect to backend
|
||||||
|
let headers = { "Authorization": `Bearer ${this.settings.khojApiKey}` };
|
||||||
|
|
||||||
|
if (this.settings.khojUrl === "https://app.khoj.dev") {
|
||||||
|
if (this.settings.khojApiKey === "") {
|
||||||
|
new Notice(`❗️Khoj API key is not configured. Please visit https://app.khoj.dev to get an API key.`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
await request({ url: this.settings.khojUrl ,method: "GET", headers: headers })
|
||||||
|
.then(response => {
|
||||||
|
this.settings.connectedToBackend = true;
|
||||||
|
})
|
||||||
|
.catch(error => {
|
||||||
|
this.settings.connectedToBackend = false;
|
||||||
|
new Notice(`❗️Ensure Khoj backend is running and Khoj URL is pointing to it in the plugin settings.\n\n${error}`);
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async saveSettings() {
|
async saveSettings() {
|
||||||
|
|
|
@ -28,7 +28,7 @@ from khoj.utils.config import (
|
||||||
from khoj.utils.fs_syncer import collect_files
|
from khoj.utils.fs_syncer import collect_files
|
||||||
from khoj.utils.rawconfig import FullConfig
|
from khoj.utils.rawconfig import FullConfig
|
||||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||||
from database.models import KhojUser
|
from database.models import KhojUser, Subscription
|
||||||
from database.adapters import get_all_users
|
from database.adapters import get_all_users
|
||||||
|
|
||||||
|
|
||||||
|
@ -54,27 +54,37 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
||||||
|
|
||||||
def _initialize_default_user(self):
|
def _initialize_default_user(self):
|
||||||
if not self.khojuser_manager.filter(username="default").exists():
|
if not self.khojuser_manager.filter(username="default").exists():
|
||||||
self.khojuser_manager.create_user(
|
default_user = self.khojuser_manager.create_user(
|
||||||
username="default",
|
username="default",
|
||||||
email="default@example.com",
|
email="default@example.com",
|
||||||
password="default",
|
password="default",
|
||||||
)
|
)
|
||||||
|
Subscription.objects.create(user=default_user, type="standard", renewal_date="2100-04-01")
|
||||||
|
|
||||||
async def authenticate(self, request: HTTPConnection):
|
async def authenticate(self, request: HTTPConnection):
|
||||||
current_user = request.session.get("user")
|
current_user = request.session.get("user")
|
||||||
if current_user and current_user.get("email"):
|
if current_user and current_user.get("email"):
|
||||||
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
|
user = (
|
||||||
|
await self.khojuser_manager.filter(email=current_user.get("email"))
|
||||||
|
.prefetch_related("subscription")
|
||||||
|
.afirst()
|
||||||
|
)
|
||||||
if user:
|
if user:
|
||||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||||
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
|
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
|
||||||
# Get bearer token from header
|
# Get bearer token from header
|
||||||
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
|
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
|
||||||
# Get user owning token
|
# Get user owning token
|
||||||
user_with_token = await self.khojapiuser_manager.filter(token=bearer_token).select_related("user").afirst()
|
user_with_token = (
|
||||||
|
await self.khojapiuser_manager.filter(token=bearer_token)
|
||||||
|
.select_related("user")
|
||||||
|
.prefetch_related("user__subscription")
|
||||||
|
.afirst()
|
||||||
|
)
|
||||||
if user_with_token:
|
if user_with_token:
|
||||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
||||||
if state.anonymous_mode:
|
if state.anonymous_mode:
|
||||||
user = await self.khojuser_manager.filter(username="default").afirst()
|
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
|
||||||
if user:
|
if user:
|
||||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||||
|
|
||||||
|
|
|
@ -109,7 +109,7 @@
|
||||||
display: grid;
|
display: grid;
|
||||||
grid-template-rows: repeat(3, 1fr);
|
grid-template-rows: repeat(3, 1fr);
|
||||||
gap: 8px;
|
gap: 8px;
|
||||||
padding: 24px 16px;
|
padding: 24px 16px 8px;
|
||||||
width: 320px;
|
width: 320px;
|
||||||
height: 180px;
|
height: 180px;
|
||||||
background: var(--background-color);
|
background: var(--background-color);
|
||||||
|
@ -121,7 +121,7 @@
|
||||||
div.finalize-buttons {
|
div.finalize-buttons {
|
||||||
display: grid;
|
display: grid;
|
||||||
gap: 8px;
|
gap: 8px;
|
||||||
padding: 24px 16px;
|
padding: 32px 0px 0px;
|
||||||
width: 320px;
|
width: 320px;
|
||||||
border-radius: 4px;
|
border-radius: 4px;
|
||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
|
@ -162,10 +162,13 @@
|
||||||
color: grey;
|
color: grey;
|
||||||
font-size: 16px;
|
font-size: 16px;
|
||||||
}
|
}
|
||||||
.card-button-row {
|
.card-description-row {
|
||||||
|
padding-top: 4px;
|
||||||
|
}
|
||||||
|
.card-action-row {
|
||||||
display: grid;
|
display: grid;
|
||||||
grid-template-columns: auto;
|
grid-auto-flow: row;
|
||||||
text-align: right;
|
justify-content: left;
|
||||||
}
|
}
|
||||||
.card-button {
|
.card-button {
|
||||||
border: none;
|
border: none;
|
||||||
|
@ -271,7 +274,9 @@
|
||||||
100% { transform: rotate(360deg); }
|
100% { transform: rotate(360deg); }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#status {
|
||||||
|
padding-top: 32px;
|
||||||
|
}
|
||||||
div.finalize-actions {
|
div.finalize-actions {
|
||||||
grid-auto-flow: column;
|
grid-auto-flow: column;
|
||||||
grid-gap: 24px;
|
grid-gap: 24px;
|
||||||
|
@ -287,6 +292,7 @@
|
||||||
|
|
||||||
select#chat-models {
|
select#chat-models {
|
||||||
margin-bottom: 0;
|
margin-bottom: 0;
|
||||||
|
padding: 8px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -343,6 +349,12 @@
|
||||||
width: auto;
|
width: auto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#status {
|
||||||
|
padding-top: 12px;
|
||||||
|
}
|
||||||
|
div.finalize-actions {
|
||||||
|
padding: 12px 0 0;
|
||||||
|
}
|
||||||
div.finalize-buttons {
|
div.finalize-buttons {
|
||||||
padding: 0;
|
padding: 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,7 +43,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
let escaped_ref = reference.replaceAll('"', '"');
|
let escaped_ref = reference.replaceAll('"', '"');
|
||||||
|
|
||||||
// Generate HTML for Chat Reference
|
// Generate HTML for Chat Reference
|
||||||
let short_ref = escaped_ref.slice(0, 100);
|
let short_ref = escaped_ref.slice(0, 140);
|
||||||
short_ref = short_ref.length < escaped_ref.length ? short_ref + "..." : short_ref;
|
short_ref = short_ref.length < escaped_ref.length ? short_ref + "..." : short_ref;
|
||||||
let referenceButton = document.createElement('button');
|
let referenceButton = document.createElement('button');
|
||||||
referenceButton.innerHTML = short_ref;
|
referenceButton.innerHTML = short_ref;
|
||||||
|
@ -205,8 +205,11 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
// Evaluate the contents of new_response_text.innerHTML after all the data has been streamed
|
// Evaluate the contents of new_response_text.innerHTML after all the data has been streamed
|
||||||
const currentHTML = newResponseText.innerHTML;
|
const currentHTML = newResponseText.innerHTML;
|
||||||
newResponseText.innerHTML = formatHTMLMessage(currentHTML);
|
newResponseText.innerHTML = formatHTMLMessage(currentHTML);
|
||||||
newResponseText.appendChild(references);
|
if (references != null) {
|
||||||
|
newResponseText.appendChild(references);
|
||||||
|
}
|
||||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||||
|
document.getElementById("chat-input").removeAttribute("disabled");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -265,7 +268,6 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
readStream();
|
readStream();
|
||||||
document.getElementById("chat-input").removeAttribute("disabled");
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -417,6 +419,9 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
display: block;
|
display: block;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
div.references {
|
||||||
|
padding-top: 8px;
|
||||||
|
}
|
||||||
div.reference {
|
div.reference {
|
||||||
display: grid;
|
display: grid;
|
||||||
grid-template-rows: auto;
|
grid-template-rows: auto;
|
||||||
|
@ -447,12 +452,12 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
transition: background 0.2s ease-in-out;
|
transition: background 0.2s ease-in-out;
|
||||||
text-align: left;
|
text-align: left;
|
||||||
max-height: 50px;
|
max-height: 75px;
|
||||||
transition: max-height 0.3s ease-in-out;
|
transition: max-height 0.3s ease-in-out;
|
||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
}
|
}
|
||||||
button.reference-button.expanded {
|
button.reference-button.expanded {
|
||||||
max-height: 200px;
|
max-height: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
button.reference-button::before {
|
button.reference-button::before {
|
||||||
|
|
|
@ -29,12 +29,12 @@
|
||||||
{% endif %}
|
{% endif %}
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||||
</a>
|
</a>
|
||||||
</div>
|
<div id="clear-computer" class="card-action-row"
|
||||||
<div id="clear-computer" class="card-action-row"
|
style="display: {% if not current_model_state.computer %}none{% endif %}">
|
||||||
style="display: {% if not current_model_state.computer %}none{% endif %}">
|
<button class="card-button" onclick="clearContentType('computer')">
|
||||||
<button class="card-button" onclick="clearContentType('computer')">
|
Disable
|
||||||
Disable
|
</button>
|
||||||
</button>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="card">
|
<div class="card">
|
||||||
|
@ -61,13 +61,13 @@
|
||||||
{% endif %}
|
{% endif %}
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||||
</a>
|
</a>
|
||||||
</div>
|
<div id="clear-github"
|
||||||
<div id="clear-github"
|
class="card-action-row"
|
||||||
class="card-action-row"
|
style="display: {% if not current_model_state.github %}none{% endif %}">
|
||||||
style="display: {% if not current_model_state.github %}none{% endif %}">
|
<button class="card-button" onclick="clearContentType('github')">
|
||||||
<button class="card-button" onclick="clearContentType('github')">
|
Disable
|
||||||
Disable
|
</button>
|
||||||
</button>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="card">
|
<div class="card">
|
||||||
|
@ -94,13 +94,26 @@
|
||||||
{% endif %}
|
{% endif %}
|
||||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||||
</a>
|
</a>
|
||||||
|
<div id="clear-notion"
|
||||||
|
class="card-action-row"
|
||||||
|
style="display: {% if not current_model_state.notion %}none{% endif %}">
|
||||||
|
<button class="card-button" onclick="clearContentType('notion')">
|
||||||
|
Disable
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div id="clear-notion"
|
</div>
|
||||||
class="card-action-row"
|
</div>
|
||||||
style="display: {% if not current_model_state.notion %}none{% endif %}">
|
<div class="general-settings section">
|
||||||
<button class="card-button" onclick="clearContentType('notion')">
|
<div id="status" style="display: none;"></div>
|
||||||
Disable
|
</div>
|
||||||
</button>
|
<div class="section finalize-actions general-settings">
|
||||||
|
<div class="section-cards">
|
||||||
|
<div class="finalize-buttons">
|
||||||
|
<button id="configure" type="submit" title="Update index with the latest changes">💾 Save All</button>
|
||||||
|
</div>
|
||||||
|
<div class="finalize-buttons">
|
||||||
|
<button id="reinitialize" type="submit" title="Regenerate index from scratch">🔄 Reinitialize</button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
@ -221,23 +234,7 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
<div class="section general-settings">
|
<div class="section"></div>
|
||||||
<div id="results-count" title="Number of items to show in search and use for chat response">
|
|
||||||
<label for="results-count-slider">Results Count: <span id="results-count-value">5</span></label>
|
|
||||||
<input type="range" id="results-count-slider" name="results-count-slider" min="1" max="10" step="1" value="5">
|
|
||||||
</div>
|
|
||||||
<div id="status" style="display: none;"></div>
|
|
||||||
</div>
|
|
||||||
<div class="section finalize-actions general-settings">
|
|
||||||
<div class="section-cards">
|
|
||||||
<div class="finalize-buttons">
|
|
||||||
<button id="configure" type="submit" title="Update index with the latest changes">⚙️ Configure</button>
|
|
||||||
</div>
|
|
||||||
<div class="finalize-buttons">
|
|
||||||
<button id="reinitialize" type="submit" title="Regenerate index from scratch">🔄 Reinitialize</button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
<script>
|
<script>
|
||||||
|
|
||||||
|
@ -329,11 +326,11 @@
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
updateIndex(
|
updateIndex(
|
||||||
force=false,
|
force=false,
|
||||||
successText="Configured successfully!",
|
successText="Saved!",
|
||||||
errorText="Unable to configure. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
|
errorText="Unable to configure. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
|
||||||
button=configure,
|
button=configure,
|
||||||
loadingText="Configuring...",
|
loadingText="Saving...",
|
||||||
emoji="⚙️");
|
emoji="💾");
|
||||||
});
|
});
|
||||||
|
|
||||||
var reinitialize = document.getElementById("reinitialize");
|
var reinitialize = document.getElementById("reinitialize");
|
||||||
|
@ -341,7 +338,7 @@
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
updateIndex(
|
updateIndex(
|
||||||
force=true,
|
force=true,
|
||||||
successText="Reinitialized successfully!",
|
successText="Reinitialized!",
|
||||||
errorText="Unable to reinitialize. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
|
errorText="Unable to reinitialize. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
|
||||||
button=reinitialize,
|
button=reinitialize,
|
||||||
loadingText="Reinitializing...",
|
loadingText="Reinitializing...",
|
||||||
|
@ -350,6 +347,7 @@
|
||||||
|
|
||||||
function updateIndex(force, successText, errorText, button, loadingText, emoji) {
|
function updateIndex(force, successText, errorText, button, loadingText, emoji) {
|
||||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||||
|
const original_html = button.innerHTML;
|
||||||
button.disabled = true;
|
button.disabled = true;
|
||||||
button.innerHTML = emoji + " " + loadingText;
|
button.innerHTML = emoji + " " + loadingText;
|
||||||
fetch('/api/update?&client=web&force=' + force, {
|
fetch('/api/update?&client=web&force=' + force, {
|
||||||
|
@ -361,15 +359,17 @@
|
||||||
})
|
})
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
console.log('Success:', data);
|
|
||||||
if (data.detail != null) {
|
if (data.detail != null) {
|
||||||
throw new Error(data.detail);
|
throw new Error(data.detail);
|
||||||
}
|
}
|
||||||
|
|
||||||
document.getElementById("status").innerHTML = emoji + " " + successText;
|
document.getElementById("status").style.display = "none";
|
||||||
document.getElementById("status").style.display = "block";
|
|
||||||
button.disabled = false;
|
button.disabled = false;
|
||||||
button.innerHTML = '✅ Done!';
|
button.innerHTML = `✅ ${successText}`;
|
||||||
|
setTimeout(function() {
|
||||||
|
button.innerHTML = original_html;
|
||||||
|
}, 2000);
|
||||||
})
|
})
|
||||||
.catch((error) => {
|
.catch((error) => {
|
||||||
console.error('Error:', error);
|
console.error('Error:', error);
|
||||||
|
@ -377,6 +377,9 @@
|
||||||
document.getElementById("status").style.display = "block";
|
document.getElementById("status").style.display = "block";
|
||||||
button.disabled = false;
|
button.disabled = false;
|
||||||
button.innerHTML = '⚠️ Unsuccessful';
|
button.innerHTML = '⚠️ Unsuccessful';
|
||||||
|
setTimeout(function() {
|
||||||
|
button.innerHTML = original_html;
|
||||||
|
}, 2000);
|
||||||
});
|
});
|
||||||
|
|
||||||
content_sources = ["computer", "github", "notion"];
|
content_sources = ["computer", "github", "notion"];
|
||||||
|
@ -400,26 +403,6 @@
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup the results count slider
|
|
||||||
const resultsCountSlider = document.getElementById('results-count-slider');
|
|
||||||
const resultsCountValue = document.getElementById('results-count-value');
|
|
||||||
|
|
||||||
// Set the initial value of the slider
|
|
||||||
resultsCountValue.textContent = resultsCountSlider.value;
|
|
||||||
|
|
||||||
// Store the slider value in localStorage when it changes
|
|
||||||
resultsCountSlider.addEventListener('input', () => {
|
|
||||||
resultsCountValue.textContent = resultsCountSlider.value;
|
|
||||||
localStorage.setItem('khojResultsCount', resultsCountSlider.value);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Get the slider value from localStorage on page load
|
|
||||||
const storedResultsCount = localStorage.getItem('khojResultsCount');
|
|
||||||
if (storedResultsCount) {
|
|
||||||
resultsCountSlider.value = storedResultsCount;
|
|
||||||
resultsCountValue.textContent = storedResultsCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
function generateAPIKey() {
|
function generateAPIKey() {
|
||||||
const apiKeyList = document.getElementById("api-key-list");
|
const apiKeyList = document.getElementById("api-key-list");
|
||||||
fetch('/auth/token', {
|
fetch('/auth/token', {
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
<span class="card-title-text">Files</span>
|
<span class="card-title-text">Files</span>
|
||||||
<div class="instructions">
|
<div class="instructions">
|
||||||
<p class="card-description">Manage files from your computer</p>
|
<p class="card-description">Manage files from your computer</p>
|
||||||
<p id="get-desktop-client" class="card-description">Download the <a href="https://download.khoj.dev">Khoj Desktop app</a> to sync documents from your computer</p>
|
<p id="get-desktop-client" class="card-description">Download the <a href="https://khoj.dev/downloads">Khoj Desktop app</a> to sync documents from your computer</p>
|
||||||
</div>
|
</div>
|
||||||
</h2>
|
</h2>
|
||||||
<div class="section-manage-files">
|
<div class="section-manage-files">
|
||||||
|
|
|
@ -46,6 +46,9 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<style>
|
<style>
|
||||||
|
td {
|
||||||
|
padding: 10px 0;
|
||||||
|
}
|
||||||
div.repo {
|
div.repo {
|
||||||
width: 100%;
|
width: 100%;
|
||||||
height: 100%;
|
height: 100%;
|
||||||
|
@ -124,6 +127,11 @@
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const submitButton = document.getElementById("submit");
|
||||||
|
submitButton.disabled = true;
|
||||||
|
submitButton.innerHTML = "Saving...";
|
||||||
|
|
||||||
|
// Save Github config on server
|
||||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||||
fetch('/api/config/data/content-source/github', {
|
fetch('/api/config/data/content-source/github', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
|
@ -137,15 +145,40 @@
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
|
.then(data => { data["status"] === "ok" ? data : Promise.reject(data) })
|
||||||
|
.catch(error => {
|
||||||
|
document.getElementById("success").innerHTML = "⚠️ Failed to save Github settings.";
|
||||||
|
document.getElementById("success").style.display = "block";
|
||||||
|
submitButton.innerHTML = "⚠️ Failed to save settings";
|
||||||
|
setTimeout(function() {
|
||||||
|
submitButton.innerHTML = "Save";
|
||||||
|
submitButton.disabled = false;
|
||||||
|
}, 2000);
|
||||||
|
return;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Index Github content on server
|
||||||
|
fetch('/api/update?t=github')
|
||||||
|
.then(response => response.json())
|
||||||
|
.then(data => { data["status"] == "ok" ? data : Promise.reject(data) })
|
||||||
.then(data => {
|
.then(data => {
|
||||||
if (data["status"] == "ok") {
|
document.getElementById("success").style.display = "none";
|
||||||
document.getElementById("success").innerHTML = "✅ Successfully updated. Go to your <a href='/config'>settings page</a> to complete setup.";
|
submitButton.innerHTML = "✅ Successfully updated";
|
||||||
document.getElementById("success").style.display = "block";
|
setTimeout(function() {
|
||||||
} else {
|
submitButton.innerHTML = "Save";
|
||||||
document.getElementById("success").innerHTML = "⚠️ Failed to update settings.";
|
submitButton.disabled = false;
|
||||||
document.getElementById("success").style.display = "block";
|
}, 2000);
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
.catch(error => {
|
||||||
|
document.getElementById("success").innerHTML = "⚠️ Failed to save Github content.";
|
||||||
|
document.getElementById("success").style.display = "block";
|
||||||
|
submitButton.innerHTML = "⚠️ Failed to save content";
|
||||||
|
setTimeout(function() {
|
||||||
|
submitButton.innerHTML = "Save";
|
||||||
|
submitButton.disabled = false;
|
||||||
|
}, 2000);
|
||||||
|
});
|
||||||
|
|
||||||
});
|
});
|
||||||
</script>
|
</script>
|
||||||
{% endblock %}
|
{% endblock %}
|
||||||
|
|
|
@ -41,6 +41,11 @@
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const submitButton = document.getElementById("submit");
|
||||||
|
submitButton.disabled = true;
|
||||||
|
submitButton.innerHTML = "Saving...";
|
||||||
|
|
||||||
|
// Save Notion config on server
|
||||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||||
fetch('/api/config/data/content-source/notion', {
|
fetch('/api/config/data/content-source/notion', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
|
@ -53,15 +58,39 @@
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
|
.then(data => { data["status"] === "ok" ? data : Promise.reject(data) })
|
||||||
|
.catch(error => {
|
||||||
|
document.getElementById("success").innerHTML = "⚠️ Failed to save Notion settings.";
|
||||||
|
document.getElementById("success").style.display = "block";
|
||||||
|
submitButton.innerHTML = "⚠️ Failed to save settings";
|
||||||
|
setTimeout(function() {
|
||||||
|
submitButton.innerHTML = "Save";
|
||||||
|
submitButton.disabled = false;
|
||||||
|
}, 2000);
|
||||||
|
return;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Index Notion content on server
|
||||||
|
fetch('/api/update?t=notion')
|
||||||
|
.then(response => response.json())
|
||||||
|
.then(data => { data["status"] == "ok" ? data : Promise.reject(data) })
|
||||||
.then(data => {
|
.then(data => {
|
||||||
if (data["status"] == "ok") {
|
document.getElementById("success").style.display = "none";
|
||||||
document.getElementById("success").innerHTML = "✅ Successfully updated. Go to your <a href='/config'>settings page</a> to complete setup.";
|
submitButton.innerHTML = "✅ Successfully updated";
|
||||||
document.getElementById("success").style.display = "block";
|
setTimeout(function() {
|
||||||
} else {
|
submitButton.innerHTML = "Save";
|
||||||
document.getElementById("success").innerHTML = "⚠️ Failed to update settings.";
|
submitButton.disabled = false;
|
||||||
document.getElementById("success").style.display = "block";
|
}, 2000);
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
.catch(error => {
|
||||||
|
document.getElementById("success").innerHTML = "⚠️ Failed to save Notion content.";
|
||||||
|
document.getElementById("success").style.display = "block";
|
||||||
|
submitButton.innerHTML = "⚠️ Failed to save content";
|
||||||
|
setTimeout(function() {
|
||||||
|
submitButton.innerHTML = "Save";
|
||||||
|
submitButton.disabled = false;
|
||||||
|
}, 2000);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
</script>
|
</script>
|
||||||
{% endblock %}
|
{% endblock %}
|
||||||
|
|
|
@ -189,7 +189,6 @@
|
||||||
})
|
})
|
||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
console.log(data);
|
|
||||||
document.getElementById("results").innerHTML = render_results(data, query, type);
|
document.getElementById("results").innerHTML = render_results(data, query, type);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,6 +56,7 @@ locale.setlocale(locale.LC_ALL, "")
|
||||||
from khoj.configure import configure_routes, initialize_server, configure_middleware
|
from khoj.configure import configure_routes, initialize_server, configure_middleware
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.cli import cli
|
from khoj.utils.cli import cli
|
||||||
|
from khoj.utils.initialization import initialization
|
||||||
|
|
||||||
# Setup Logger
|
# Setup Logger
|
||||||
rich_handler = RichHandler(rich_tracebacks=True)
|
rich_handler = RichHandler(rich_tracebacks=True)
|
||||||
|
@ -74,8 +75,7 @@ def run(should_start_server=True):
|
||||||
args = cli(state.cli_args)
|
args = cli(state.cli_args)
|
||||||
set_state(args)
|
set_state(args)
|
||||||
|
|
||||||
# Create app directory, if it doesn't exist
|
logger.info(f"🚒 Initializing Khoj v{state.khoj_version}")
|
||||||
state.config_file.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Set Logging Level
|
# Set Logging Level
|
||||||
if args.verbose == 0:
|
if args.verbose == 0:
|
||||||
|
@ -83,6 +83,11 @@ def run(should_start_server=True):
|
||||||
elif args.verbose >= 1:
|
elif args.verbose >= 1:
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
initialization()
|
||||||
|
|
||||||
|
# Create app directory, if it doesn't exist
|
||||||
|
state.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Set Log File
|
# Set Log File
|
||||||
fh = logging.FileHandler(state.config_file.parent / "khoj.log", encoding="utf-8")
|
fh = logging.FileHandler(state.config_file.parent / "khoj.log", encoding="utf-8")
|
||||||
fh.setLevel(logging.DEBUG)
|
fh.setLevel(logging.DEBUG)
|
||||||
|
@ -97,7 +102,7 @@ def run(should_start_server=True):
|
||||||
configure_routes(app)
|
configure_routes(app)
|
||||||
|
|
||||||
# Mount Django and Static Files
|
# Mount Django and Static Files
|
||||||
app.mount("/django", django_app, name="django")
|
app.mount("/server", django_app, name="server")
|
||||||
static_dir = "static"
|
static_dir = "static"
|
||||||
if not os.path.exists(static_dir):
|
if not os.path.exists(static_dir):
|
||||||
os.mkdir(static_dir)
|
os.mkdir(static_dir)
|
||||||
|
|
|
@ -55,10 +55,10 @@ def extract_questions_offline(
|
||||||
last_year = datetime.now().year - 1
|
last_year = datetime.now().year - 1
|
||||||
last_christmas_date = f"{last_year}-12-25"
|
last_christmas_date = f"{last_year}-12-25"
|
||||||
next_christmas_date = f"{datetime.now().year}-12-25"
|
next_christmas_date = f"{datetime.now().year}-12-25"
|
||||||
system_prompt = prompts.extract_questions_system_prompt_llamav2.format(
|
system_prompt = prompts.system_prompt_extract_questions_gpt4all.format(
|
||||||
message=(prompts.system_prompt_message_extract_questions_llamav2)
|
message=(prompts.system_prompt_message_extract_questions_gpt4all)
|
||||||
)
|
)
|
||||||
example_questions = prompts.extract_questions_llamav2_sample.format(
|
example_questions = prompts.extract_questions_gpt4all_sample.format(
|
||||||
query=text,
|
query=text,
|
||||||
chat_history=chat_history,
|
chat_history=chat_history,
|
||||||
current_date=current_date,
|
current_date=current_date,
|
||||||
|
@ -150,14 +150,14 @@ def converse_offline(
|
||||||
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message):
|
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message):
|
||||||
conversation_primer = user_query
|
conversation_primer = user_query
|
||||||
else:
|
else:
|
||||||
conversation_primer = prompts.notes_conversation_llamav2.format(
|
conversation_primer = prompts.notes_conversation_gpt4all.format(
|
||||||
query=user_query, references=compiled_references_message
|
query=user_query, references=compiled_references_message
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup Prompt with Primer or Conversation History
|
# Setup Prompt with Primer or Conversation History
|
||||||
messages = generate_chatml_messages_with_context(
|
messages = generate_chatml_messages_with_context(
|
||||||
conversation_primer,
|
conversation_primer,
|
||||||
prompts.system_prompt_message_llamav2,
|
prompts.system_prompt_message_gpt4all,
|
||||||
conversation_log,
|
conversation_log,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
|
@ -183,16 +183,16 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
|
||||||
conversation_history = messages[1:-1]
|
conversation_history = messages[1:-1]
|
||||||
|
|
||||||
formatted_messages = [
|
formatted_messages = [
|
||||||
prompts.chat_history_llamav2_from_assistant.format(message=message.content)
|
prompts.khoj_message_gpt4all.format(message=message.content)
|
||||||
if message.role == "assistant"
|
if message.role == "assistant"
|
||||||
else prompts.chat_history_llamav2_from_user.format(message=message.content)
|
else prompts.user_message_gpt4all.format(message=message.content)
|
||||||
for message in conversation_history
|
for message in conversation_history
|
||||||
]
|
]
|
||||||
|
|
||||||
stop_words = ["<s>"]
|
stop_words = ["<s>"]
|
||||||
chat_history = "".join(formatted_messages)
|
chat_history = "".join(formatted_messages)
|
||||||
templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
|
templated_system_message = prompts.system_prompt_gpt4all.format(message=system_message.content)
|
||||||
templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
|
templated_user_message = prompts.user_message_gpt4all.format(message=user_message.content)
|
||||||
prompted_message = templated_system_message + chat_history + templated_user_message
|
prompted_message = templated_system_message + chat_history + templated_user_message
|
||||||
|
|
||||||
state.chat_lock.acquire()
|
state.chat_lock.acquire()
|
||||||
|
|
|
@ -20,27 +20,6 @@ from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def summarize(session, model, api_key=None, temperature=0.5, max_tokens=200):
|
|
||||||
"""
|
|
||||||
Summarize conversation session using the specified OpenAI chat model
|
|
||||||
"""
|
|
||||||
messages = [ChatMessage(content=prompts.summarize_chat.format(), role="system")] + session
|
|
||||||
|
|
||||||
# Get Response from GPT
|
|
||||||
logger.debug(f"Prompt for GPT: {messages}")
|
|
||||||
response = completion_with_backoff(
|
|
||||||
messages=messages,
|
|
||||||
model_name=model,
|
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
model_kwargs={"stop": ['"""'], "frequency_penalty": 0.2},
|
|
||||||
openai_api_key=api_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract, Clean Message from GPT's Response
|
|
||||||
return str(response.content).replace("\n\n", "")
|
|
||||||
|
|
||||||
|
|
||||||
def extract_questions(
|
def extract_questions(
|
||||||
text,
|
text,
|
||||||
model: Optional[str] = "gpt-4",
|
model: Optional[str] = "gpt-4",
|
||||||
|
@ -131,16 +110,14 @@ def converse(
|
||||||
completion_func(chat_response=prompts.no_notes_found.format())
|
completion_func(chat_response=prompts.no_notes_found.format())
|
||||||
return iter([prompts.no_notes_found.format()])
|
return iter([prompts.no_notes_found.format()])
|
||||||
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references):
|
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references):
|
||||||
conversation_primer = prompts.general_conversation.format(current_date=current_date, query=user_query)
|
conversation_primer = prompts.general_conversation.format(query=user_query)
|
||||||
else:
|
else:
|
||||||
conversation_primer = prompts.notes_conversation.format(
|
conversation_primer = prompts.notes_conversation.format(query=user_query, references=compiled_references)
|
||||||
current_date=current_date, query=user_query, references=compiled_references
|
|
||||||
)
|
|
||||||
|
|
||||||
# Setup Prompt with Primer or Conversation History
|
# Setup Prompt with Primer or Conversation History
|
||||||
messages = generate_chatml_messages_with_context(
|
messages = generate_chatml_messages_with_context(
|
||||||
conversation_primer,
|
conversation_primer,
|
||||||
prompts.personality.format(),
|
prompts.personality.format(current_date=current_date),
|
||||||
conversation_log,
|
conversation_log,
|
||||||
model,
|
model,
|
||||||
max_prompt_size,
|
max_prompt_size,
|
||||||
|
@ -157,4 +134,5 @@ def converse(
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
openai_api_key=api_key,
|
openai_api_key=api_key,
|
||||||
completion_func=completion_func,
|
completion_func=completion_func,
|
||||||
|
model_kwargs={"stop": ["Notes:\n["]},
|
||||||
)
|
)
|
||||||
|
|
|
@ -69,15 +69,15 @@ def completion_with_backoff(**kwargs):
|
||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def chat_completion_with_backoff(
|
def chat_completion_with_backoff(
|
||||||
messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
|
messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None, model_kwargs=None
|
||||||
):
|
):
|
||||||
g = ThreadedGenerator(compiled_references, completion_func=completion_func)
|
g = ThreadedGenerator(compiled_references, completion_func=completion_func)
|
||||||
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
|
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs))
|
||||||
t.start()
|
t.start()
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
|
def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_kwargs=None):
|
||||||
callback_handler = StreamingChatCallbackHandler(g)
|
callback_handler = StreamingChatCallbackHandler(g)
|
||||||
chat = ChatOpenAI(
|
chat = ChatOpenAI(
|
||||||
streaming=True,
|
streaming=True,
|
||||||
|
@ -86,6 +86,7 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
|
||||||
model_name=model_name, # type: ignore
|
model_name=model_name, # type: ignore
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
request_timeout=20,
|
request_timeout=20,
|
||||||
max_retries=1,
|
max_retries=1,
|
||||||
client=None,
|
client=None,
|
||||||
|
|
|
@ -4,30 +4,44 @@ from langchain.prompts import PromptTemplate
|
||||||
|
|
||||||
## Personality
|
## Personality
|
||||||
## --
|
## --
|
||||||
personality = PromptTemplate.from_template("You are Khoj, a smart, inquisitive and helpful personal assistant.")
|
personality = PromptTemplate.from_template(
|
||||||
|
"""
|
||||||
|
You are Khoj, a smart, inquisitive and helpful personal assistant.
|
||||||
|
Use your general knowledge and the past conversation with the user as context to inform your responses.
|
||||||
|
You were created by Khoj Inc. with the following capabilities:
|
||||||
|
|
||||||
|
- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you.
|
||||||
|
- You cannot set reminders.
|
||||||
|
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
|
||||||
|
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
|
||||||
|
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
|
||||||
|
|
||||||
|
Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev.
|
||||||
|
Today is {current_date} in UTC.
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
## General Conversation
|
## General Conversation
|
||||||
## --
|
## --
|
||||||
general_conversation = PromptTemplate.from_template(
|
general_conversation = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
Using your general knowledge and our past conversations as context, answer the following question.
|
{query}
|
||||||
Current Date: {current_date}
|
|
||||||
|
|
||||||
Question: {query}
|
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
no_notes_found = PromptTemplate.from_template(
|
no_notes_found = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
I'm sorry, I couldn't find any relevant notes to respond to your message.
|
I'm sorry, I couldn't find any relevant notes to respond to your message.
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
system_prompt_message_llamav2 = f"""You are Khoj, a smart, inquisitive and helpful personal assistant.
|
## Conversation Prompts for GPT4All Models
|
||||||
|
## --
|
||||||
|
system_prompt_message_gpt4all = f"""You are Khoj, a smart, inquisitive and helpful personal assistant.
|
||||||
Using your general knowledge and our past conversations as context, answer the following question.
|
Using your general knowledge and our past conversations as context, answer the following question.
|
||||||
If you do not know the answer, say 'I don't know.'"""
|
If you do not know the answer, say 'I don't know.'"""
|
||||||
|
|
||||||
system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and intelligent personal assistant. When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
|
system_prompt_message_extract_questions_gpt4all = f"""You are Khoj, a kind and intelligent personal assistant. When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
|
||||||
- Write the question as if you can search for the answer on the user's personal notes.
|
- Write the question as if you can search for the answer on the user's personal notes.
|
||||||
- Try to be as specific as possible. Instead of saying "they" or "it" or "he", use the name of the person or thing you are referring to. For example, instead of saying "Which store did they go to?", say "Which store did Alice and Bob go to?".
|
- Try to be as specific as possible. Instead of saying "they" or "it" or "he", use the name of the person or thing you are referring to. For example, instead of saying "Which store did they go to?", say "Which store did Alice and Bob go to?".
|
||||||
- Add as much context from the previous questions and notes as required into your search queries.
|
- Add as much context from the previous questions and notes as required into your search queries.
|
||||||
|
@ -35,61 +49,47 @@ system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and i
|
||||||
What follow-up questions, if any, will you need to ask to answer the user's question?
|
What follow-up questions, if any, will you need to ask to answer the user's question?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
system_prompt_llamav2 = PromptTemplate.from_template(
|
system_prompt_gpt4all = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
<s>[INST] <<SYS>>
|
<s>[INST] <<SYS>>
|
||||||
{message}
|
{message}
|
||||||
<</SYS>>Hi there! [/INST] Hello! How can I help you today? </s>"""
|
<</SYS>>Hi there! [/INST] Hello! How can I help you today? </s>"""
|
||||||
)
|
)
|
||||||
|
|
||||||
extract_questions_system_prompt_llamav2 = PromptTemplate.from_template(
|
system_prompt_extract_questions_gpt4all = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
<s>[INST] <<SYS>>
|
<s>[INST] <<SYS>>
|
||||||
{message}
|
{message}
|
||||||
<</SYS>>[/INST]</s>"""
|
<</SYS>>[/INST]</s>"""
|
||||||
)
|
)
|
||||||
|
|
||||||
general_conversation_llamav2 = PromptTemplate.from_template(
|
user_message_gpt4all = PromptTemplate.from_template(
|
||||||
"""
|
|
||||||
<s>[INST] {query} [/INST]
|
|
||||||
""".strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
chat_history_llamav2_from_user = PromptTemplate.from_template(
|
|
||||||
"""
|
"""
|
||||||
<s>[INST] {message} [/INST]
|
<s>[INST] {message} [/INST]
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_history_llamav2_from_assistant = PromptTemplate.from_template(
|
khoj_message_gpt4all = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
{message}</s>
|
{message}</s>
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_llamav2 = PromptTemplate.from_template(
|
|
||||||
"""
|
|
||||||
<s>[INST] {query} [/INST]
|
|
||||||
""".strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
## Notes Conversation
|
## Notes Conversation
|
||||||
## --
|
## --
|
||||||
notes_conversation = PromptTemplate.from_template(
|
notes_conversation = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
Using my personal notes and our past conversations as context, answer the following question.
|
Use my personal notes and our past conversations to inform your response.
|
||||||
Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
|
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided notes or past conversations.
|
||||||
These questions should end with a question mark.
|
|
||||||
Current Date: {current_date}
|
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
{references}
|
{references}
|
||||||
|
|
||||||
Question: {query}
|
Query: {query}
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
notes_conversation_llamav2 = PromptTemplate.from_template(
|
notes_conversation_gpt4all = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
User's Notes:
|
User's Notes:
|
||||||
{references}
|
{references}
|
||||||
|
@ -98,13 +98,6 @@ Question: {query}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
## Summarize Chat
|
|
||||||
## --
|
|
||||||
summarize_chat = PromptTemplate.from_template(
|
|
||||||
f"{personality.format()} Summarize the conversation from your first person perspective"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
## Summarize Notes
|
## Summarize Notes
|
||||||
## --
|
## --
|
||||||
summarize_notes = PromptTemplate.from_template(
|
summarize_notes = PromptTemplate.from_template(
|
||||||
|
@ -132,7 +125,10 @@ Question: {user_query}
|
||||||
Answer (in second person):"""
|
Answer (in second person):"""
|
||||||
)
|
)
|
||||||
|
|
||||||
extract_questions_llamav2_sample = PromptTemplate.from_template(
|
|
||||||
|
## Extract Questions
|
||||||
|
## --
|
||||||
|
extract_questions_gpt4all_sample = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
<s>[INST] <<SYS>>Current Date: {current_date}<</SYS>> [/INST]</s>
|
<s>[INST] <<SYS>>Current Date: {current_date}<</SYS>> [/INST]</s>
|
||||||
<s>[INST] How was my trip to Cambodia? [/INST]
|
<s>[INST] How was my trip to Cambodia? [/INST]
|
||||||
|
@ -157,8 +153,6 @@ Use these notes from the user's previous conversations to provide a response:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
## Extract Questions
|
|
||||||
## --
|
|
||||||
extract_questions = PromptTemplate.from_template(
|
extract_questions = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the user's notes.
|
You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the user's notes.
|
||||||
|
|
|
@ -27,5 +27,5 @@ class CrossEncoderModel:
|
||||||
|
|
||||||
def predict(self, query, hits: List[SearchResponse]):
|
def predict(self, query, hits: List[SearchResponse]):
|
||||||
cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
|
cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
|
||||||
cross_scores = self.cross_encoder_model.predict(cross__inp)
|
cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True)
|
||||||
return cross_scores
|
return cross_scores
|
||||||
|
|
|
@ -7,7 +7,7 @@ import json
|
||||||
from typing import List, Optional, Union, Any
|
from typing import List, Optional, Union, Any
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from fastapi import APIRouter, HTTPException, Header, Request
|
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
|
@ -36,6 +36,7 @@ from khoj.routers.helpers import (
|
||||||
agenerate_chat_response,
|
agenerate_chat_response,
|
||||||
update_telemetry_state,
|
update_telemetry_state,
|
||||||
is_ready_to_chat,
|
is_ready_to_chat,
|
||||||
|
ApiUserRateLimiter,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.prompts import help_message
|
from khoj.processor.conversation.prompts import help_message
|
||||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||||
|
@ -177,11 +178,15 @@ async def set_content_config_github_data(
|
||||||
|
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
|
|
||||||
await adapters.set_user_github_config(
|
try:
|
||||||
user=user,
|
await adapters.set_user_github_config(
|
||||||
pat_token=updated_config.pat_token,
|
user=user,
|
||||||
repos=updated_config.repos,
|
pat_token=updated_config.pat_token,
|
||||||
)
|
repos=updated_config.repos,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to set Github config")
|
||||||
|
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=request,
|
request=request,
|
||||||
|
@ -205,10 +210,14 @@ async def set_content_config_notion_data(
|
||||||
|
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
|
|
||||||
await adapters.set_notion_config(
|
try:
|
||||||
user=user,
|
await adapters.set_notion_config(
|
||||||
token=updated_config.token,
|
user=user,
|
||||||
)
|
token=updated_config.token,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e, exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to set Github config")
|
||||||
|
|
||||||
update_telemetry_state(
|
update_telemetry_state(
|
||||||
request=request,
|
request=request,
|
||||||
|
@ -348,7 +357,7 @@ async def search(
|
||||||
n: Optional[int] = 5,
|
n: Optional[int] = 5,
|
||||||
t: Optional[SearchType] = SearchType.All,
|
t: Optional[SearchType] = SearchType.All,
|
||||||
r: Optional[bool] = False,
|
r: Optional[bool] = False,
|
||||||
score_threshold: Optional[Union[float, None]] = None,
|
max_distance: Optional[Union[float, None]] = None,
|
||||||
dedupe: Optional[bool] = True,
|
dedupe: Optional[bool] = True,
|
||||||
client: Optional[str] = None,
|
client: Optional[str] = None,
|
||||||
user_agent: Optional[str] = Header(None),
|
user_agent: Optional[str] = Header(None),
|
||||||
|
@ -367,12 +376,12 @@ async def search(
|
||||||
# initialize variables
|
# initialize variables
|
||||||
user_query = q.strip()
|
user_query = q.strip()
|
||||||
results_count = n or 5
|
results_count = n or 5
|
||||||
score_threshold = score_threshold if score_threshold is not None else -math.inf
|
max_distance = max_distance if max_distance is not None else math.inf
|
||||||
search_futures: List[concurrent.futures.Future] = []
|
search_futures: List[concurrent.futures.Future] = []
|
||||||
|
|
||||||
# return cached results, if available
|
# return cached results, if available
|
||||||
if user:
|
if user:
|
||||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
|
query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
|
||||||
if query_cache_key in state.query_cache[user.uuid]:
|
if query_cache_key in state.query_cache[user.uuid]:
|
||||||
logger.debug(f"Return response from query cache")
|
logger.debug(f"Return response from query cache")
|
||||||
return state.query_cache[user.uuid][query_cache_key]
|
return state.query_cache[user.uuid][query_cache_key]
|
||||||
|
@ -409,8 +418,7 @@ async def search(
|
||||||
user_query,
|
user_query,
|
||||||
t,
|
t,
|
||||||
question_embedding=encoded_asymmetric_query,
|
question_embedding=encoded_asymmetric_query,
|
||||||
rank_results=r or False,
|
max_distance=max_distance,
|
||||||
score_threshold=score_threshold,
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -423,7 +431,6 @@ async def search(
|
||||||
results_count,
|
results_count,
|
||||||
state.search_models.image_search,
|
state.search_models.image_search,
|
||||||
state.content_index.image,
|
state.content_index.image,
|
||||||
score_threshold=score_threshold,
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -446,11 +453,10 @@ async def search(
|
||||||
# Collate results
|
# Collate results
|
||||||
results += text_search.collate_results(hits, dedupe=dedupe)
|
results += text_search.collate_results(hits, dedupe=dedupe)
|
||||||
|
|
||||||
if r:
|
|
||||||
results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
|
|
||||||
else:
|
|
||||||
# Sort results across all content types and take top results
|
# Sort results across all content types and take top results
|
||||||
results = sorted(results, key=lambda x: float(x.score))[:results_count]
|
results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[
|
||||||
|
:results_count
|
||||||
|
]
|
||||||
|
|
||||||
# Cache results
|
# Cache results
|
||||||
if user:
|
if user:
|
||||||
|
@ -575,11 +581,14 @@ async def chat(
|
||||||
request: Request,
|
request: Request,
|
||||||
q: str,
|
q: str,
|
||||||
n: Optional[int] = 5,
|
n: Optional[int] = 5,
|
||||||
|
d: Optional[float] = 0.15,
|
||||||
client: Optional[str] = None,
|
client: Optional[str] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
user_agent: Optional[str] = Header(None),
|
user_agent: Optional[str] = Header(None),
|
||||||
referer: Optional[str] = Header(None),
|
referer: Optional[str] = Header(None),
|
||||||
host: Optional[str] = Header(None),
|
host: Optional[str] = Header(None),
|
||||||
|
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
|
||||||
|
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
|
|
||||||
|
@ -591,7 +600,7 @@ async def chat(
|
||||||
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
||||||
|
|
||||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||||
request, meta_log, q, (n or 5), conversation_command
|
request, meta_log, q, (n or 5), (d or math.inf), conversation_command
|
||||||
)
|
)
|
||||||
|
|
||||||
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
|
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
|
||||||
|
@ -606,7 +615,7 @@ async def chat(
|
||||||
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
|
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
|
||||||
|
|
||||||
# Get the (streamed) chat response from the LLM of choice.
|
# Get the (streamed) chat response from the LLM of choice.
|
||||||
llm_response = await agenerate_chat_response(
|
llm_response, chat_metadata = await agenerate_chat_response(
|
||||||
defiltered_query,
|
defiltered_query,
|
||||||
meta_log,
|
meta_log,
|
||||||
compiled_references,
|
compiled_references,
|
||||||
|
@ -615,6 +624,19 @@ async def chat(
|
||||||
user,
|
user,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
chat_metadata.update({"conversation_command": conversation_command.value})
|
||||||
|
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="chat",
|
||||||
|
client=client,
|
||||||
|
user_agent=user_agent,
|
||||||
|
referer=referer,
|
||||||
|
host=host,
|
||||||
|
metadata=chat_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
if llm_response is None:
|
if llm_response is None:
|
||||||
return Response(content=llm_response, media_type="text/plain", status_code=500)
|
return Response(content=llm_response, media_type="text/plain", status_code=500)
|
||||||
|
|
||||||
|
@ -634,16 +656,6 @@ async def chat(
|
||||||
|
|
||||||
response_obj = {"response": actual_response, "context": compiled_references}
|
response_obj = {"response": actual_response, "context": compiled_references}
|
||||||
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="chat",
|
|
||||||
client=client,
|
|
||||||
user_agent=user_agent,
|
|
||||||
referer=referer,
|
|
||||||
host=host,
|
|
||||||
)
|
|
||||||
|
|
||||||
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
|
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@ -652,6 +664,7 @@ async def extract_references_and_questions(
|
||||||
meta_log: dict,
|
meta_log: dict,
|
||||||
q: str,
|
q: str,
|
||||||
n: int,
|
n: int,
|
||||||
|
d: float,
|
||||||
conversation_type: ConversationCommand = ConversationCommand.Default,
|
conversation_type: ConversationCommand = ConversationCommand.Default,
|
||||||
):
|
):
|
||||||
user = request.user.object if request.user.is_authenticated else None
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
|
@ -663,7 +676,7 @@ async def extract_references_and_questions(
|
||||||
if conversation_type == ConversationCommand.General:
|
if conversation_type == ConversationCommand.General:
|
||||||
return compiled_references, inferred_queries, q
|
return compiled_references, inferred_queries, q
|
||||||
|
|
||||||
if not sync_to_async(EntryAdapters.user_has_entries)(user=user):
|
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
|
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
|
||||||
)
|
)
|
||||||
|
@ -712,7 +725,7 @@ async def extract_references_and_questions(
|
||||||
request=request,
|
request=request,
|
||||||
n=n_items,
|
n=n_items,
|
||||||
r=True,
|
r=True,
|
||||||
score_threshold=-5.0,
|
max_distance=d,
|
||||||
dedupe=False,
|
dedupe=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,7 +16,7 @@ from google.auth.transport import requests as google_requests
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from database.adapters import get_khoj_tokens, get_or_create_user, create_khoj_token, delete_khoj_token
|
from database.adapters import get_khoj_tokens, get_or_create_user, create_khoj_token, delete_khoj_token
|
||||||
from database.models import KhojApiUser
|
from khoj.routers.helpers import update_telemetry_state
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
|
|
||||||
|
|
||||||
|
@ -100,6 +100,16 @@ async def auth(request: Request):
|
||||||
if khoj_user:
|
if khoj_user:
|
||||||
request.session["user"] = dict(idinfo)
|
request.session["user"] = dict(idinfo)
|
||||||
|
|
||||||
|
if not khoj_user.last_login:
|
||||||
|
update_telemetry_state(
|
||||||
|
request=request,
|
||||||
|
telemetry_type="api",
|
||||||
|
api="create_user",
|
||||||
|
metadata={"user_id": str(khoj_user.uuid)},
|
||||||
|
)
|
||||||
|
logger.log(logging.INFO, f"New User Created: {khoj_user.uuid}")
|
||||||
|
RedirectResponse(url="/?status=welcome")
|
||||||
|
|
||||||
return RedirectResponse(url="/")
|
return RedirectResponse(url="/")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,21 +1,27 @@
|
||||||
import logging
|
# Standard Packages
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections import defaultdict
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Iterator, List, Optional, Union
|
import logging
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from time import time
|
||||||
|
from typing import Iterator, List, Optional, Union, Tuple, Dict
|
||||||
|
|
||||||
|
# External Packages
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
# Internal Packages
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.config import GPT4AllProcessorModel
|
from khoj.utils.config import GPT4AllProcessorModel
|
||||||
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
||||||
from khoj.processor.conversation.openai.gpt import converse
|
from khoj.processor.conversation.openai.gpt import converse
|
||||||
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
|
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
|
||||||
from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
|
from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
|
||||||
from database.models import KhojUser
|
from database.models import KhojUser, Subscription
|
||||||
from database.adapters import ConversationAdapters
|
from database.adapters import ConversationAdapters
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
executor = ThreadPoolExecutor(max_workers=1)
|
executor = ThreadPoolExecutor(max_workers=1)
|
||||||
|
@ -61,12 +67,15 @@ def update_telemetry_state(
|
||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
user: KhojUser = request.user.object if request.user.is_authenticated else None
|
user: KhojUser = request.user.object if request.user.is_authenticated else None
|
||||||
|
subscription: Subscription = user.subscription if user and user.subscription else None
|
||||||
user_state = {
|
user_state = {
|
||||||
"client_host": request.client.host if request.client else None,
|
"client_host": request.client.host if request.client else None,
|
||||||
"user_agent": user_agent or "unknown",
|
"user_agent": user_agent or "unknown",
|
||||||
"referer": referer or "unknown",
|
"referer": referer or "unknown",
|
||||||
"host": host or "unknown",
|
"host": host or "unknown",
|
||||||
"server_id": str(user.uuid) if user else None,
|
"server_id": str(user.uuid) if user else None,
|
||||||
|
"subscription_type": subscription.type if subscription else None,
|
||||||
|
"is_recurring": subscription.is_recurring if subscription else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
|
@ -109,7 +118,7 @@ def generate_chat_response(
|
||||||
inferred_queries: List[str] = [],
|
inferred_queries: List[str] = [],
|
||||||
conversation_command: ConversationCommand = ConversationCommand.Default,
|
conversation_command: ConversationCommand = ConversationCommand.Default,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||||
def _save_to_conversation_log(
|
def _save_to_conversation_log(
|
||||||
q: str,
|
q: str,
|
||||||
chat_response: str,
|
chat_response: str,
|
||||||
|
@ -132,6 +141,8 @@ def generate_chat_response(
|
||||||
chat_response = None
|
chat_response = None
|
||||||
logger.debug(f"Conversation Type: {conversation_command.name}")
|
logger.debug(f"Conversation Type: {conversation_command.name}")
|
||||||
|
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
partial_completion = partial(
|
partial_completion = partial(
|
||||||
_save_to_conversation_log,
|
_save_to_conversation_log,
|
||||||
|
@ -148,8 +159,8 @@ def generate_chat_response(
|
||||||
conversation_config = ConversationAdapters.get_default_conversation_config()
|
conversation_config = ConversationAdapters.get_default_conversation_config()
|
||||||
openai_chat_config = ConversationAdapters.get_openai_conversation_config()
|
openai_chat_config = ConversationAdapters.get_openai_conversation_config()
|
||||||
if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
|
if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
|
||||||
if state.gpt4all_processor_config.loaded_model is None:
|
if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
|
||||||
state.gpt4all_processor_config = GPT4AllProcessorModel(offline_chat_config.chat_model)
|
state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
|
||||||
|
|
||||||
loaded_model = state.gpt4all_processor_config.loaded_model
|
loaded_model = state.gpt4all_processor_config.loaded_model
|
||||||
chat_response = converse_offline(
|
chat_response = converse_offline(
|
||||||
|
@ -179,8 +190,33 @@ def generate_chat_response(
|
||||||
tokenizer_name=conversation_config.tokenizer,
|
tokenizer_name=conversation_config.tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
metadata.update({"chat_model": conversation_config.chat_model})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e, exc_info=True)
|
logger.error(e, exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
return chat_response
|
return chat_response, metadata
|
||||||
|
|
||||||
|
|
||||||
|
class ApiUserRateLimiter:
    """Callable sliding-window rate limiter keyed by the requesting user's uuid.

    Each instance keeps, per user, the timestamps of the calls it has accepted.
    Calling the instance with a request raises an HTTP 429 once that user has
    already made ``requests`` accepted calls within the trailing ``window``
    seconds; otherwise the call is recorded and allowed through.

    The bookkeeping lives in a plain in-process dictionary on the instance, so
    limits are tracked separately by every instance (and every process).
    """

    def __init__(self, requests: int, window: int):
        """
        Args:
            requests: Maximum number of calls allowed per user within the window.
            window: Length of the sliding window, in seconds.
        """
        self.requests = requests
        self.window = window
        # Maps user uuid -> timestamps of accepted calls, oldest first
        self.cache: dict[str, list[float]] = defaultdict(list)

    def __call__(self, request: Request):
        """Record a call for the request's user, or reject it with HTTP 429.

        Raises:
            HTTPException: With status 429 when the user is over the limit.
        """
        user: KhojUser = request.user.object
        user_requests = self.cache[user.uuid]

        # Read the clock once so the cutoff and the recorded timestamp agree
        now = time()
        cutoff = now - self.window

        # Remove requests outside of the time window.
        # Timestamps are appended in arrival order, so only the leading entries
        # can have expired. Drop them with a single in-place slice deletion
        # rather than repeated pop(0), each of which shifts the whole list.
        expired = 0
        for timestamp in user_requests:
            if timestamp >= cutoff:
                break
            expired += 1
        if expired:
            del user_requests[:expired]

        # Check if the user has exceeded the rate limit
        if len(user_requests) >= self.requests:
            raise HTTPException(status_code=429, detail="Too Many Requests")

        # Add the current request to the cache
        user_requests.append(now)
|
||||||
|
|
|
@ -146,7 +146,7 @@ def extract_metadata(image_name):
|
||||||
|
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
|
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf
|
||||||
):
|
):
|
||||||
# Set query to image content if query is of form file:/path/to/file.png
|
# Set query to image content if query is of form file:/path/to/file.png
|
||||||
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
|
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
|
||||||
|
@ -167,7 +167,8 @@ async def query(
|
||||||
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
||||||
with timer("Search Time", logger):
|
with timer("Search Time", logger):
|
||||||
image_hits = {
|
image_hits = {
|
||||||
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
|
# Map scores to distance metric by multiplying by -1
|
||||||
|
result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]}
|
||||||
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
|
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -204,7 +205,7 @@ async def query(
|
||||||
]
|
]
|
||||||
|
|
||||||
# Filter results by score threshold
|
# Filter results by score threshold
|
||||||
hits = [hit for hit in hits if hit["image_score"] >= score_threshold]
|
hits = [hit for hit in hits if hit["image_score"] <= score_threshold]
|
||||||
|
|
||||||
# Sort the images based on their combined metadata, image scores
|
# Sort the images based on their combined metadata, image scores
|
||||||
return sorted(hits, key=lambda hit: hit["score"], reverse=True)
|
return sorted(hits, key=lambda hit: hit["score"], reverse=True)
|
||||||
|
|
|
@ -104,8 +104,7 @@ async def query(
|
||||||
raw_query: str,
|
raw_query: str,
|
||||||
type: SearchType = SearchType.All,
|
type: SearchType = SearchType.All,
|
||||||
question_embedding: Union[torch.Tensor, None] = None,
|
question_embedding: Union[torch.Tensor, None] = None,
|
||||||
rank_results: bool = False,
|
max_distance: float = math.inf,
|
||||||
score_threshold: float = -math.inf,
|
|
||||||
) -> Tuple[List[dict], List[Entry]]:
|
) -> Tuple[List[dict], List[Entry]]:
|
||||||
"Search for entries that answer the query"
|
"Search for entries that answer the query"
|
||||||
|
|
||||||
|
@ -127,6 +126,7 @@ async def query(
|
||||||
max_results=top_k,
|
max_results=top_k,
|
||||||
file_type_filter=file_type,
|
file_type_filter=file_type,
|
||||||
raw_query=raw_query,
|
raw_query=raw_query,
|
||||||
|
max_distance=max_distance,
|
||||||
).all()
|
).all()
|
||||||
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
|
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
|
||||||
|
|
||||||
|
@ -177,12 +177,16 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def rerank_and_sort_results(hits, query):
|
def rerank_and_sort_results(hits, query, rank_results):
|
||||||
|
# If we have more than one result and reranking is enabled
|
||||||
|
rank_results = rank_results and len(list(hits)) > 1
|
||||||
|
|
||||||
# Score all retrieved entries using the cross-encoder
|
# Score all retrieved entries using the cross-encoder
|
||||||
hits = cross_encoder_score(query, hits)
|
if rank_results:
|
||||||
|
hits = cross_encoder_score(query, hits)
|
||||||
|
|
||||||
# Sort results by cross-encoder score followed by bi-encoder score
|
# Sort results by cross-encoder score followed by bi-encoder score
|
||||||
hits = sort_results(rank_results=True, hits=hits)
|
hits = sort_results(rank_results=rank_results, hits=hits)
|
||||||
|
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
|
@ -217,9 +221,9 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
|
||||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||||
cross_scores = state.cross_encoder_model.predict(query, hits)
|
cross_scores = state.cross_encoder_model.predict(query, hits)
|
||||||
|
|
||||||
# Store cross-encoder scores in results dictionary for ranking
|
# Convert cross-encoder scores to distances and pass in hits for reranking
|
||||||
for idx in range(len(cross_scores)):
|
for idx in range(len(cross_scores)):
|
||||||
hits[idx]["cross_score"] = cross_scores[idx]
|
hits[idx]["cross_score"] = 1 - cross_scores[idx]
|
||||||
|
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
|
@ -227,7 +231,7 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
|
||||||
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
|
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
|
||||||
"""Order results by cross-encoder score followed by bi-encoder score"""
|
"""Order results by cross-encoder score followed by bi-encoder score"""
|
||||||
with timer("Rank Time", logger, state.device):
|
with timer("Rank Time", logger, state.device):
|
||||||
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
|
hits.sort(key=lambda x: x["score"]) # sort by bi-encoder score
|
||||||
if rank_results:
|
if rank_results:
|
||||||
hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score
|
hits.sort(key=lambda x: x["cross_score"]) # sort by cross-encoder score
|
||||||
return hits
|
return hits
|
||||||
|
|
|
@ -6,6 +6,7 @@ empty_escape_sequences = "\n|\r|\t| "
|
||||||
app_env_filepath = "~/.khoj/env"
|
app_env_filepath = "~/.khoj/env"
|
||||||
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
|
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
|
||||||
content_directory = "~/.khoj/content/"
|
content_directory = "~/.khoj/content/"
|
||||||
|
default_offline_chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf"
|
||||||
|
|
||||||
empty_config = {
|
empty_config = {
|
||||||
"search-type": {
|
"search-type": {
|
||||||
|
|
98
src/khoj/utils/initialization.py
Normal file
98
src/khoj/utils/initialization.py
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from database.models import (
|
||||||
|
KhojUser,
|
||||||
|
OfflineChatProcessorConversationConfig,
|
||||||
|
OpenAIProcessorConversationConfig,
|
||||||
|
ChatModelOptions,
|
||||||
|
)
|
||||||
|
|
||||||
|
from khoj.utils.constants import default_offline_chat_model
|
||||||
|
|
||||||
|
from database.adapters import ConversationAdapters
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def initialization():
    """Bootstrap a fresh server by creating an admin user and chat model configuration.

    An admin user is created when no staff user exists yet. Credentials are read
    from the ``KHOJ_ADMIN_EMAIL`` / ``KHOJ_ADMIN_PASSWORD`` environment variables,
    falling back to interactive prompts. Chat models are then configured
    interactively, but only when no admin existed at startup and no default
    conversation config is present.
    """

    def _create_admin_user():
        # Create the superuser from environment variables or interactive input
        logger.info(
            "👩✈️ Setting up admin user. These credentials will allow you to configure your server at /server/admin."
        )
        email_addr = os.getenv("KHOJ_ADMIN_EMAIL") or input("Email: ")
        password = os.getenv("KHOJ_ADMIN_PASSWORD") or input("Password: ")
        admin_user = KhojUser.objects.create_superuser(email=email_addr, username=email_addr, password=password)
        logger.info(f"👩✈️ Created admin user: {admin_user.email}")

    def _create_chat_configuration():
        # Walk the operator through enabling offline and/or OpenAI chat models
        logger.info(
            "🗣️ Configure chat models available to your server. You can always update these at /server/admin using the credentials of your admin account"
        )
        try:
            # Some environments don't support interactive input. We catch the exception and return if that's the case. The admin can still configure their settings from the admin page.
            input()
        except EOFError:
            return

        try:
            # Note: gpt4all package is not available on all devices.
            # So ensure gpt4all package is installed before continuing this step.
            import gpt4all

            use_offline_model = input("Use offline chat model? (y/n): ")
            if use_offline_model == "y":
                logger.info("🗣️ Setting up offline chat model")
                OfflineChatProcessorConversationConfig.objects.create(enabled=True)

                offline_chat_model = input(
                    f"Enter the name of the offline chat model you want to use, based on the models in HuggingFace (press enter to use the default: {default_offline_chat_model}): "
                )
                if offline_chat_model == "":
                    ChatModelOptions.objects.create(
                        chat_model=default_offline_chat_model, model_type=ChatModelOptions.ModelType.OFFLINE
                    )
                else:
                    max_tokens = input("Enter the maximum number of tokens to use for the offline chat model:")
                    tokenizer = input("Enter the tokenizer to use for the offline chat model:")
                    ChatModelOptions.objects.create(
                        chat_model=offline_chat_model,
                        model_type=ChatModelOptions.ModelType.OFFLINE,
                        max_prompt_size=max_tokens,
                        tokenizer=tokenizer,
                    )
        except ModuleNotFoundError:
            logger.warning("Offline models are not supported on this device.")

        use_openai_model = input("Use OpenAI chat model? (y/n): ")

        if use_openai_model == "y":
            logger.info("🗣️ Setting up OpenAI chat model")
            api_key = input("Enter your OpenAI API key: ")
            OpenAIProcessorConversationConfig.objects.create(api_key=api_key)
            openai_chat_model = input("Enter the name of the OpenAI chat model you want to use: ")
            max_tokens = input("Enter the maximum number of tokens to use for the OpenAI chat model:")
            # Store the token limit under the same field name the offline branch
            # uses (max_prompt_size) so both model types are configured alike.
            ChatModelOptions.objects.create(
                chat_model=openai_chat_model,
                model_type=ChatModelOptions.ModelType.OPENAI,
                max_prompt_size=max_tokens,
            )

        logger.info("🗣️ Chat model configuration complete")

    admin_user = KhojUser.objects.filter(is_staff=True).first()
    if admin_user is None:
        # Keep prompting until an admin user is created successfully
        while True:
            try:
                _create_admin_user()
                break
            except Exception as e:
                logger.error(f"🚨 Failed to create admin user: {e}", exc_info=True)

    chat_config = ConversationAdapters.get_default_conversation_config()
    # admin_user still holds the lookup result from before any admin was created
    # above, so chat setup only runs on a server that started without an admin.
    if admin_user is None and chat_config is None:
        while True:
            try:
                _create_chat_configuration()
                break
            except Exception as e:
                logger.error(f"🚨 Failed to create chat configuration: {e}", exc_info=True)
|
|
@ -43,6 +43,7 @@ from tests.helpers import (
|
||||||
OpenAIProcessorConversationConfigFactory,
|
OpenAIProcessorConversationConfigFactory,
|
||||||
OfflineChatProcessorConversationConfigFactory,
|
OfflineChatProcessorConversationConfigFactory,
|
||||||
UserConversationProcessorConfigFactory,
|
UserConversationProcessorConfigFactory,
|
||||||
|
SubscriptionFactory,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -69,7 +70,9 @@ def search_config() -> SearchConfig:
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def default_user():
|
def default_user():
|
||||||
return UserFactory()
|
user = UserFactory()
|
||||||
|
SubscriptionFactory(user=user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
@ -78,11 +81,31 @@ def default_user2():
|
||||||
if KhojUser.objects.filter(username="default").exists():
|
if KhojUser.objects.filter(username="default").exists():
|
||||||
return KhojUser.objects.get(username="default")
|
return KhojUser.objects.get(username="default")
|
||||||
|
|
||||||
return KhojUser.objects.create(
|
user = KhojUser.objects.create(
|
||||||
username="default",
|
username="default",
|
||||||
email="default@example.com",
|
email="default@example.com",
|
||||||
password="default",
|
password="default",
|
||||||
)
|
)
|
||||||
|
SubscriptionFactory(user=user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
@pytest.fixture
def default_user3():
    """
    This user should not have any data associated with it
    """
    # Reuse the user if an earlier test already created it
    existing_users = KhojUser.objects.filter(username="default3")
    if existing_users.exists():
        return KhojUser.objects.get(username="default3")

    # Otherwise create the user and attach a subscription to it
    new_user = KhojUser.objects.create(
        username="default3",
        email="default3@example.com",
        password="default3",
    )
    SubscriptionFactory(user=new_user)
    return new_user
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
|
@ -111,6 +134,19 @@ def api_user2(default_user2):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
@pytest.fixture
def api_user3(default_user3):
    """Return the API user for default_user3, creating it on first use."""
    already_registered = KhojApiUser.objects.filter(user=default_user3).exists()
    if not already_registered:
        return KhojApiUser.objects.create(
            user=default_user3,
            name="api-key",
            token="kk-diff-secret-3",
        )

    return KhojApiUser.objects.get(user=default_user3)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def search_models(search_config: SearchConfig):
|
def search_models(search_config: SearchConfig):
|
||||||
search_models = SearchModels()
|
search_models = SearchModels()
|
||||||
|
@ -206,7 +242,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser):
|
||||||
OpenAIProcessorConversationConfigFactory()
|
OpenAIProcessorConversationConfigFactory()
|
||||||
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
|
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
|
||||||
|
|
||||||
state.anonymous_mode = False
|
state.anonymous_mode = True
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
@ -224,7 +260,9 @@ def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUs
|
||||||
|
|
||||||
# Initialize Processor from Config
|
# Initialize Processor from Config
|
||||||
if os.getenv("OPENAI_API_KEY"):
|
if os.getenv("OPENAI_API_KEY"):
|
||||||
|
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
|
||||||
OpenAIProcessorConversationConfigFactory()
|
OpenAIProcessorConversationConfigFactory()
|
||||||
|
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
|
||||||
|
|
||||||
state.anonymous_mode = True
|
state.anonymous_mode = True
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ from database.models import (
|
||||||
OpenAIProcessorConversationConfig,
|
OpenAIProcessorConversationConfig,
|
||||||
UserConversationConfig,
|
UserConversationConfig,
|
||||||
Conversation,
|
Conversation,
|
||||||
|
Subscription,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -68,3 +69,13 @@ class ConversationFactory(factory.django.DjangoModelFactory):
|
||||||
model = Conversation
|
model = Conversation
|
||||||
|
|
||||||
user = factory.SubFactory(UserFactory)
|
user = factory.SubFactory(UserFactory)
|
||||||
|
|
||||||
|
|
||||||
|
class SubscriptionFactory(factory.django.DjangoModelFactory):
    """Factory that builds Subscription rows for tests."""

    class Meta:
        model = Subscription

    # A fresh user is built via UserFactory unless the caller passes user=...
    user = factory.SubFactory(UserFactory)

    # Fixed defaults applied to every generated subscription
    type = "standard"
    is_recurring = False
    renewal_date = "2100-04-01"
|
||||||
|
|
|
@ -16,7 +16,7 @@ from khoj.utils.state import search_models, content_index, config
|
||||||
from khoj.search_type import text_search, image_search
|
from khoj.search_type import text_search, image_search
|
||||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||||
from khoj.processor.org_mode.org_to_entries import OrgToEntries
|
from khoj.processor.org_mode.org_to_entries import OrgToEntries
|
||||||
from database.models import KhojUser
|
from database.models import KhojUser, KhojApiUser
|
||||||
from database.adapters import EntryAdapters
|
from database.adapters import EntryAdapters
|
||||||
|
|
||||||
|
|
||||||
|
@ -351,6 +351,24 @@ def test_different_user_data_not_accessed(client, sample_org_data, default_user:
|
||||||
assert len(response.json()) == 1 and response.json()["detail"] == "Forbidden"
|
assert len(response.json()) == 1 and response.json()["detail"] == "Forbidden"
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.django_db(transaction=True)
def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiUser):
    """Search as a user with no indexed entries should return an empty result list."""
    # Arrange
    auth_headers = {"Authorization": "Bearer " + api_user3.token}
    encoded_query = quote("How to git install application?")
    search_url = f"/api/search?q={encoded_query}&n=1&t=org"

    # Act
    response = client.get(search_url, headers=auth_headers)

    # Assert
    assert response.status_code == 200
    # assert actual response has no data as the default_user3, though other users have data
    payload = response.json()
    assert len(payload) == 0
    assert payload == []
|
||||||
|
|
||||||
|
|
||||||
def get_sample_files_data():
|
def get_sample_files_data():
|
||||||
return [
|
return [
|
||||||
("files", ("path/to/filename.org", "* practicing piano", "text/org")),
|
("files", ("path/to/filename.org", "* practicing piano", "text/org")),
|
||||||
|
|
|
@ -307,6 +307,8 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_
|
||||||
"which one is",
|
"which one is",
|
||||||
"which of namita's sons",
|
"which of namita's sons",
|
||||||
"the birth order",
|
"the birth order",
|
||||||
|
"provide more context",
|
||||||
|
"provide me with more context",
|
||||||
]
|
]
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||||
|
|
Loading…
Reference in a new issue