mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Resolve merge conflicts in auth.py with remove KhojApiUser import
This commit is contained in:
commit
6b17aeb32d
35 changed files with 605 additions and 266 deletions
|
@ -10,7 +10,15 @@ services:
|
|||
POSTGRES_DB: postgres
|
||||
volumes:
|
||||
- khoj_db:/var/lib/postgresql/data/
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
server:
|
||||
depends_on:
|
||||
database:
|
||||
condition: service_healthy
|
||||
# Use the following line to use the latest version of khoj. Otherwise, it will build from source.
|
||||
image: ghcr.io/khoj-ai/khoj:latest
|
||||
# Uncomment the following line to build from source. This will take a few minutes. Comment the next two lines out if you want to use the offiicial image.
|
||||
|
@ -24,20 +32,6 @@ services:
|
|||
- "42110:42110"
|
||||
working_dir: /app
|
||||
volumes:
|
||||
- .:/app
|
||||
# These mounted volumes hold the raw data that should be indexed for search.
|
||||
# The path in your local directory (left hand side)
|
||||
# points to the files you want to index.
|
||||
# The path of the mounted directory (right hand side),
|
||||
# must match the path prefix in your config file.
|
||||
- ./tests/data/org/:/data/org/
|
||||
- ./tests/data/images/:/data/images/
|
||||
- ./tests/data/markdown/:/data/markdown/
|
||||
- ./tests/data/pdf/:/data/pdf/
|
||||
# Embeddings and models are populated after the first run
|
||||
# You can set these volumes to point to empty directories on host
|
||||
- ./tests/data/embeddings/:/root/.khoj/content/
|
||||
- ./tests/data/models/:/root/.khoj/search/
|
||||
- khoj_config:/root/.khoj/
|
||||
- khoj_models:/root/.cache/torch/sentence_transformers
|
||||
# Use 0.0.0.0 to explicitly set the host ip for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/
|
||||
|
@ -47,9 +41,11 @@ services:
|
|||
- POSTGRES_PASSWORD=postgres
|
||||
- POSTGRES_HOST=database
|
||||
- POSTGRES_PORT=5432
|
||||
- GOOGLE_CLIENT_SECRET=bar
|
||||
- GOOGLE_CLIENT_ID=foo
|
||||
command: --host="0.0.0.0" --port=42110 -vv
|
||||
- KHOJ_DJANGO_SECRET_KEY=secret
|
||||
- KHOJ_DEBUG=True
|
||||
- KHOJ_ADMIN_EMAIL=username@example.com
|
||||
- KHOJ_ADMIN_PASSWORD=password
|
||||
command: --host="0.0.0.0" --port=42110 -vv --anonymous-mode
|
||||
|
||||
|
||||
volumes:
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import math
|
||||
from typing import Optional, Type, TypeVar, List
|
||||
from datetime import date, datetime, timedelta
|
||||
import secrets
|
||||
|
@ -99,6 +100,8 @@ async def create_google_user(token: dict) -> KhojUser:
|
|||
user=user,
|
||||
)
|
||||
|
||||
await Subscription.objects.acreate(user=user, type="trial")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
|
@ -433,12 +436,19 @@ class EntryAdapters:
|
|||
|
||||
@staticmethod
|
||||
def search_with_embeddings(
|
||||
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
|
||||
user: KhojUser,
|
||||
embeddings: Tensor,
|
||||
max_results: int = 10,
|
||||
file_type_filter: str = None,
|
||||
raw_query: str = None,
|
||||
max_distance: float = math.inf,
|
||||
):
|
||||
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
|
||||
relevant_entries = relevant_entries.filter(user=user).annotate(
|
||||
distance=CosineDistance("embeddings", embeddings)
|
||||
)
|
||||
relevant_entries = relevant_entries.filter(distance__lte=max_distance)
|
||||
|
||||
if file_type_filter:
|
||||
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
|
||||
relevant_entries = relevant_entries.order_by("distance")
|
||||
|
|
|
@ -8,6 +8,7 @@ from database.models import (
|
|||
ChatModelOptions,
|
||||
OpenAIProcessorConversationConfig,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
Subscription,
|
||||
)
|
||||
|
||||
admin.site.register(KhojUser, UserAdmin)
|
||||
|
@ -15,3 +16,4 @@ admin.site.register(KhojUser, UserAdmin)
|
|||
admin.site.register(ChatModelOptions)
|
||||
admin.site.register(OpenAIProcessorConversationConfig)
|
||||
admin.site.register(OfflineChatProcessorConversationConfig)
|
||||
admin.site.register(Subscription)
|
||||
|
|
21
src/database/migrations/0015_alter_subscription_user.py
Normal file
21
src/database/migrations/0015_alter_subscription_user.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
# Generated by Django 4.2.5 on 2023-11-11 05:39
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0014_alter_googleuser_picture"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="subscription",
|
||||
name="user",
|
||||
field=models.OneToOneField(
|
||||
on_delete=django.db.models.deletion.CASCADE, related_name="subscription", to=settings.AUTH_USER_MODEL
|
||||
),
|
||||
),
|
||||
]
|
|
@ -0,0 +1,17 @@
|
|||
# Generated by Django 4.2.5 on 2023-11-11 06:15
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0015_alter_subscription_user"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="subscription",
|
||||
name="renewal_date",
|
||||
field=models.DateTimeField(blank=True, default=None, null=True),
|
||||
),
|
||||
]
|
|
@ -51,10 +51,10 @@ class Subscription(BaseModel):
|
|||
TRIAL = "trial"
|
||||
STANDARD = "standard"
|
||||
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE, related_name="subscription")
|
||||
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
|
||||
is_recurring = models.BooleanField(default=False)
|
||||
renewal_date = models.DateTimeField(null=True, default=None)
|
||||
renewal_date = models.DateTimeField(null=True, default=None, blank=True)
|
||||
|
||||
|
||||
class NotionConfig(BaseModel):
|
||||
|
|
|
@ -577,12 +577,12 @@
|
|||
cursor: pointer;
|
||||
transition: background 0.2s ease-in-out;
|
||||
text-align: left;
|
||||
max-height: 50px;
|
||||
max-height: 75px;
|
||||
transition: max-height 0.3s ease-in-out;
|
||||
overflow: hidden;
|
||||
}
|
||||
button.reference-button.expanded {
|
||||
max-height: 200px;
|
||||
max-height: none;
|
||||
}
|
||||
|
||||
button.reference-button::before {
|
||||
|
|
|
@ -198,12 +198,6 @@ khojKeyInput.addEventListener('blur', async () => {
|
|||
khojKeyInput.value = token;
|
||||
});
|
||||
|
||||
const syncButton = document.getElementById('sync-data');
|
||||
syncButton.addEventListener('click', async () => {
|
||||
loadingBar.style.display = 'block';
|
||||
await window.syncDataAPI.syncData(false);
|
||||
});
|
||||
|
||||
const syncForceButton = document.getElementById('sync-force');
|
||||
syncForceButton.addEventListener('click', async () => {
|
||||
loadingBar.style.display = 'block';
|
||||
|
|
|
@ -188,7 +188,6 @@
|
|||
fetch(url, { headers })
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log(data);
|
||||
document.getElementById("results").innerHTML = render_results(data, query, type);
|
||||
});
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import { Notice, Plugin } from 'obsidian';
|
||||
import { Notice, Plugin, request } from 'obsidian';
|
||||
import { KhojSetting, KhojSettingTab, DEFAULT_SETTINGS } from 'src/settings'
|
||||
import { KhojSearchModal } from 'src/search_modal'
|
||||
import { KhojChatModal } from 'src/chat_modal'
|
||||
|
@ -69,6 +69,25 @@ export default class Khoj extends Plugin {
|
|||
async loadSettings() {
|
||||
// Load khoj obsidian plugin settings
|
||||
this.settings = Object.assign({}, DEFAULT_SETTINGS, await this.loadData());
|
||||
|
||||
// Check if khoj backend is configured, note if cannot connect to backend
|
||||
let headers = { "Authorization": `Bearer ${this.settings.khojApiKey}` };
|
||||
|
||||
if (this.settings.khojUrl === "https://app.khoj.dev") {
|
||||
if (this.settings.khojApiKey === "") {
|
||||
new Notice(`❗️Khoj API key is not configured. Please visit https://app.khoj.dev to get an API key.`);
|
||||
return;
|
||||
}
|
||||
|
||||
await request({ url: this.settings.khojUrl ,method: "GET", headers: headers })
|
||||
.then(response => {
|
||||
this.settings.connectedToBackend = true;
|
||||
})
|
||||
.catch(error => {
|
||||
this.settings.connectedToBackend = false;
|
||||
new Notice(`❗️Ensure Khoj backend is running and Khoj URL is pointing to it in the plugin settings.\n\n${error}`);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async saveSettings() {
|
||||
|
|
|
@ -28,7 +28,7 @@ from khoj.utils.config import (
|
|||
from khoj.utils.fs_syncer import collect_files
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||
from database.models import KhojUser
|
||||
from database.models import KhojUser, Subscription
|
||||
from database.adapters import get_all_users
|
||||
|
||||
|
||||
|
@ -54,27 +54,37 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||
|
||||
def _initialize_default_user(self):
|
||||
if not self.khojuser_manager.filter(username="default").exists():
|
||||
self.khojuser_manager.create_user(
|
||||
default_user = self.khojuser_manager.create_user(
|
||||
username="default",
|
||||
email="default@example.com",
|
||||
password="default",
|
||||
)
|
||||
Subscription.objects.create(user=default_user, type="standard", renewal_date="2100-04-01")
|
||||
|
||||
async def authenticate(self, request: HTTPConnection):
|
||||
current_user = request.session.get("user")
|
||||
if current_user and current_user.get("email"):
|
||||
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
|
||||
user = (
|
||||
await self.khojuser_manager.filter(email=current_user.get("email"))
|
||||
.prefetch_related("subscription")
|
||||
.afirst()
|
||||
)
|
||||
if user:
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
|
||||
# Get bearer token from header
|
||||
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
|
||||
# Get user owning token
|
||||
user_with_token = await self.khojapiuser_manager.filter(token=bearer_token).select_related("user").afirst()
|
||||
user_with_token = (
|
||||
await self.khojapiuser_manager.filter(token=bearer_token)
|
||||
.select_related("user")
|
||||
.prefetch_related("user__subscription")
|
||||
.afirst()
|
||||
)
|
||||
if user_with_token:
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
||||
if state.anonymous_mode:
|
||||
user = await self.khojuser_manager.filter(username="default").afirst()
|
||||
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
|
||||
if user:
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||
|
||||
|
|
|
@ -109,7 +109,7 @@
|
|||
display: grid;
|
||||
grid-template-rows: repeat(3, 1fr);
|
||||
gap: 8px;
|
||||
padding: 24px 16px;
|
||||
padding: 24px 16px 8px;
|
||||
width: 320px;
|
||||
height: 180px;
|
||||
background: var(--background-color);
|
||||
|
@ -121,7 +121,7 @@
|
|||
div.finalize-buttons {
|
||||
display: grid;
|
||||
gap: 8px;
|
||||
padding: 24px 16px;
|
||||
padding: 32px 0px 0px;
|
||||
width: 320px;
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
|
@ -162,10 +162,13 @@
|
|||
color: grey;
|
||||
font-size: 16px;
|
||||
}
|
||||
.card-button-row {
|
||||
.card-description-row {
|
||||
padding-top: 4px;
|
||||
}
|
||||
.card-action-row {
|
||||
display: grid;
|
||||
grid-template-columns: auto;
|
||||
text-align: right;
|
||||
grid-auto-flow: row;
|
||||
justify-content: left;
|
||||
}
|
||||
.card-button {
|
||||
border: none;
|
||||
|
@ -271,7 +274,9 @@
|
|||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
|
||||
#status {
|
||||
padding-top: 32px;
|
||||
}
|
||||
div.finalize-actions {
|
||||
grid-auto-flow: column;
|
||||
grid-gap: 24px;
|
||||
|
@ -287,6 +292,7 @@
|
|||
|
||||
select#chat-models {
|
||||
margin-bottom: 0;
|
||||
padding: 8px;
|
||||
}
|
||||
|
||||
|
||||
|
@ -343,6 +349,12 @@
|
|||
width: auto;
|
||||
}
|
||||
|
||||
#status {
|
||||
padding-top: 12px;
|
||||
}
|
||||
div.finalize-actions {
|
||||
padding: 12px 0 0;
|
||||
}
|
||||
div.finalize-buttons {
|
||||
padding: 0;
|
||||
}
|
||||
|
|
|
@ -43,7 +43,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
let escaped_ref = reference.replaceAll('"', '"');
|
||||
|
||||
// Generate HTML for Chat Reference
|
||||
let short_ref = escaped_ref.slice(0, 100);
|
||||
let short_ref = escaped_ref.slice(0, 140);
|
||||
short_ref = short_ref.length < escaped_ref.length ? short_ref + "..." : short_ref;
|
||||
let referenceButton = document.createElement('button');
|
||||
referenceButton.innerHTML = short_ref;
|
||||
|
@ -205,8 +205,11 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
// Evaluate the contents of new_response_text.innerHTML after all the data has been streamed
|
||||
const currentHTML = newResponseText.innerHTML;
|
||||
newResponseText.innerHTML = formatHTMLMessage(currentHTML);
|
||||
newResponseText.appendChild(references);
|
||||
if (references != null) {
|
||||
newResponseText.appendChild(references);
|
||||
}
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
document.getElementById("chat-input").removeAttribute("disabled");
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -265,7 +268,6 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
});
|
||||
}
|
||||
readStream();
|
||||
document.getElementById("chat-input").removeAttribute("disabled");
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -417,6 +419,9 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
display: block;
|
||||
}
|
||||
|
||||
div.references {
|
||||
padding-top: 8px;
|
||||
}
|
||||
div.reference {
|
||||
display: grid;
|
||||
grid-template-rows: auto;
|
||||
|
@ -447,12 +452,12 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
cursor: pointer;
|
||||
transition: background 0.2s ease-in-out;
|
||||
text-align: left;
|
||||
max-height: 50px;
|
||||
max-height: 75px;
|
||||
transition: max-height 0.3s ease-in-out;
|
||||
overflow: hidden;
|
||||
}
|
||||
button.reference-button.expanded {
|
||||
max-height: 200px;
|
||||
max-height: none;
|
||||
}
|
||||
|
||||
button.reference-button::before {
|
||||
|
|
|
@ -29,12 +29,12 @@
|
|||
{% endif %}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||
</a>
|
||||
</div>
|
||||
<div id="clear-computer" class="card-action-row"
|
||||
style="display: {% if not current_model_state.computer %}none{% endif %}">
|
||||
<button class="card-button" onclick="clearContentType('computer')">
|
||||
Disable
|
||||
</button>
|
||||
<div id="clear-computer" class="card-action-row"
|
||||
style="display: {% if not current_model_state.computer %}none{% endif %}">
|
||||
<button class="card-button" onclick="clearContentType('computer')">
|
||||
Disable
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="card">
|
||||
|
@ -61,13 +61,13 @@
|
|||
{% endif %}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||
</a>
|
||||
</div>
|
||||
<div id="clear-github"
|
||||
class="card-action-row"
|
||||
style="display: {% if not current_model_state.github %}none{% endif %}">
|
||||
<button class="card-button" onclick="clearContentType('github')">
|
||||
Disable
|
||||
</button>
|
||||
<div id="clear-github"
|
||||
class="card-action-row"
|
||||
style="display: {% if not current_model_state.github %}none{% endif %}">
|
||||
<button class="card-button" onclick="clearContentType('github')">
|
||||
Disable
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="card">
|
||||
|
@ -94,13 +94,26 @@
|
|||
{% endif %}
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||
</a>
|
||||
<div id="clear-notion"
|
||||
class="card-action-row"
|
||||
style="display: {% if not current_model_state.notion %}none{% endif %}">
|
||||
<button class="card-button" onclick="clearContentType('notion')">
|
||||
Disable
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div id="clear-notion"
|
||||
class="card-action-row"
|
||||
style="display: {% if not current_model_state.notion %}none{% endif %}">
|
||||
<button class="card-button" onclick="clearContentType('notion')">
|
||||
Disable
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="general-settings section">
|
||||
<div id="status" style="display: none;"></div>
|
||||
</div>
|
||||
<div class="section finalize-actions general-settings">
|
||||
<div class="section-cards">
|
||||
<div class="finalize-buttons">
|
||||
<button id="configure" type="submit" title="Update index with the latest changes">💾 Save All</button>
|
||||
</div>
|
||||
<div class="finalize-buttons">
|
||||
<button id="reinitialize" type="submit" title="Regenerate index from scratch">🔄 Reinitialize</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
@ -221,23 +234,7 @@
|
|||
</div>
|
||||
</div>
|
||||
{% endif %}
|
||||
<div class="section general-settings">
|
||||
<div id="results-count" title="Number of items to show in search and use for chat response">
|
||||
<label for="results-count-slider">Results Count: <span id="results-count-value">5</span></label>
|
||||
<input type="range" id="results-count-slider" name="results-count-slider" min="1" max="10" step="1" value="5">
|
||||
</div>
|
||||
<div id="status" style="display: none;"></div>
|
||||
</div>
|
||||
<div class="section finalize-actions general-settings">
|
||||
<div class="section-cards">
|
||||
<div class="finalize-buttons">
|
||||
<button id="configure" type="submit" title="Update index with the latest changes">⚙️ Configure</button>
|
||||
</div>
|
||||
<div class="finalize-buttons">
|
||||
<button id="reinitialize" type="submit" title="Regenerate index from scratch">🔄 Reinitialize</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section"></div>
|
||||
</div>
|
||||
<script>
|
||||
|
||||
|
@ -329,11 +326,11 @@
|
|||
event.preventDefault();
|
||||
updateIndex(
|
||||
force=false,
|
||||
successText="Configured successfully!",
|
||||
successText="Saved!",
|
||||
errorText="Unable to configure. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
|
||||
button=configure,
|
||||
loadingText="Configuring...",
|
||||
emoji="⚙️");
|
||||
loadingText="Saving...",
|
||||
emoji="💾");
|
||||
});
|
||||
|
||||
var reinitialize = document.getElementById("reinitialize");
|
||||
|
@ -341,7 +338,7 @@
|
|||
event.preventDefault();
|
||||
updateIndex(
|
||||
force=true,
|
||||
successText="Reinitialized successfully!",
|
||||
successText="Reinitialized!",
|
||||
errorText="Unable to reinitialize. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
|
||||
button=reinitialize,
|
||||
loadingText="Reinitializing...",
|
||||
|
@ -350,6 +347,7 @@
|
|||
|
||||
function updateIndex(force, successText, errorText, button, loadingText, emoji) {
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
const original_html = button.innerHTML;
|
||||
button.disabled = true;
|
||||
button.innerHTML = emoji + " " + loadingText;
|
||||
fetch('/api/update?&client=web&force=' + force, {
|
||||
|
@ -361,15 +359,17 @@
|
|||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log('Success:', data);
|
||||
if (data.detail != null) {
|
||||
throw new Error(data.detail);
|
||||
}
|
||||
|
||||
document.getElementById("status").innerHTML = emoji + " " + successText;
|
||||
document.getElementById("status").style.display = "block";
|
||||
document.getElementById("status").style.display = "none";
|
||||
|
||||
button.disabled = false;
|
||||
button.innerHTML = '✅ Done!';
|
||||
button.innerHTML = `✅ ${successText}`;
|
||||
setTimeout(function() {
|
||||
button.innerHTML = original_html;
|
||||
}, 2000);
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Error:', error);
|
||||
|
@ -377,6 +377,9 @@
|
|||
document.getElementById("status").style.display = "block";
|
||||
button.disabled = false;
|
||||
button.innerHTML = '⚠️ Unsuccessful';
|
||||
setTimeout(function() {
|
||||
button.innerHTML = original_html;
|
||||
}, 2000);
|
||||
});
|
||||
|
||||
content_sources = ["computer", "github", "notion"];
|
||||
|
@ -400,26 +403,6 @@
|
|||
});
|
||||
}
|
||||
|
||||
// Setup the results count slider
|
||||
const resultsCountSlider = document.getElementById('results-count-slider');
|
||||
const resultsCountValue = document.getElementById('results-count-value');
|
||||
|
||||
// Set the initial value of the slider
|
||||
resultsCountValue.textContent = resultsCountSlider.value;
|
||||
|
||||
// Store the slider value in localStorage when it changes
|
||||
resultsCountSlider.addEventListener('input', () => {
|
||||
resultsCountValue.textContent = resultsCountSlider.value;
|
||||
localStorage.setItem('khojResultsCount', resultsCountSlider.value);
|
||||
});
|
||||
|
||||
// Get the slider value from localStorage on page load
|
||||
const storedResultsCount = localStorage.getItem('khojResultsCount');
|
||||
if (storedResultsCount) {
|
||||
resultsCountSlider.value = storedResultsCount;
|
||||
resultsCountValue.textContent = storedResultsCount;
|
||||
}
|
||||
|
||||
function generateAPIKey() {
|
||||
const apiKeyList = document.getElementById("api-key-list");
|
||||
fetch('/auth/token', {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
<span class="card-title-text">Files</span>
|
||||
<div class="instructions">
|
||||
<p class="card-description">Manage files from your computer</p>
|
||||
<p id="get-desktop-client" class="card-description">Download the <a href="https://download.khoj.dev">Khoj Desktop app</a> to sync documents from your computer</p>
|
||||
<p id="get-desktop-client" class="card-description">Download the <a href="https://khoj.dev/downloads">Khoj Desktop app</a> to sync documents from your computer</p>
|
||||
</div>
|
||||
</h2>
|
||||
<div class="section-manage-files">
|
||||
|
|
|
@ -46,6 +46,9 @@
|
|||
</div>
|
||||
</div>
|
||||
<style>
|
||||
td {
|
||||
padding: 10px 0;
|
||||
}
|
||||
div.repo {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
|
@ -124,6 +127,11 @@
|
|||
return;
|
||||
}
|
||||
|
||||
const submitButton = document.getElementById("submit");
|
||||
submitButton.disabled = true;
|
||||
submitButton.innerHTML = "Saving...";
|
||||
|
||||
// Save Github config on server
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
fetch('/api/config/data/content-source/github', {
|
||||
method: 'POST',
|
||||
|
@ -137,15 +145,40 @@
|
|||
})
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => { data["status"] === "ok" ? data : Promise.reject(data) })
|
||||
.catch(error => {
|
||||
document.getElementById("success").innerHTML = "⚠️ Failed to save Github settings.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
submitButton.innerHTML = "⚠️ Failed to save settings";
|
||||
setTimeout(function() {
|
||||
submitButton.innerHTML = "Save";
|
||||
submitButton.disabled = false;
|
||||
}, 2000);
|
||||
return;
|
||||
});
|
||||
|
||||
// Index Github content on server
|
||||
fetch('/api/update?t=github')
|
||||
.then(response => response.json())
|
||||
.then(data => { data["status"] == "ok" ? data : Promise.reject(data) })
|
||||
.then(data => {
|
||||
if (data["status"] == "ok") {
|
||||
document.getElementById("success").innerHTML = "✅ Successfully updated. Go to your <a href='/config'>settings page</a> to complete setup.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
} else {
|
||||
document.getElementById("success").innerHTML = "⚠️ Failed to update settings.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
}
|
||||
document.getElementById("success").style.display = "none";
|
||||
submitButton.innerHTML = "✅ Successfully updated";
|
||||
setTimeout(function() {
|
||||
submitButton.innerHTML = "Save";
|
||||
submitButton.disabled = false;
|
||||
}, 2000);
|
||||
})
|
||||
.catch(error => {
|
||||
document.getElementById("success").innerHTML = "⚠️ Failed to save Github content.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
submitButton.innerHTML = "⚠️ Failed to save content";
|
||||
setTimeout(function() {
|
||||
submitButton.innerHTML = "Save";
|
||||
submitButton.disabled = false;
|
||||
}, 2000);
|
||||
});
|
||||
|
||||
});
|
||||
</script>
|
||||
{% endblock %}
|
||||
|
|
|
@ -41,6 +41,11 @@
|
|||
return;
|
||||
}
|
||||
|
||||
const submitButton = document.getElementById("submit");
|
||||
submitButton.disabled = true;
|
||||
submitButton.innerHTML = "Saving...";
|
||||
|
||||
// Save Notion config on server
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
fetch('/api/config/data/content-source/notion', {
|
||||
method: 'POST',
|
||||
|
@ -53,15 +58,39 @@
|
|||
})
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => { data["status"] === "ok" ? data : Promise.reject(data) })
|
||||
.catch(error => {
|
||||
document.getElementById("success").innerHTML = "⚠️ Failed to save Notion settings.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
submitButton.innerHTML = "⚠️ Failed to save settings";
|
||||
setTimeout(function() {
|
||||
submitButton.innerHTML = "Save";
|
||||
submitButton.disabled = false;
|
||||
}, 2000);
|
||||
return;
|
||||
});
|
||||
|
||||
// Index Notion content on server
|
||||
fetch('/api/update?t=notion')
|
||||
.then(response => response.json())
|
||||
.then(data => { data["status"] == "ok" ? data : Promise.reject(data) })
|
||||
.then(data => {
|
||||
if (data["status"] == "ok") {
|
||||
document.getElementById("success").innerHTML = "✅ Successfully updated. Go to your <a href='/config'>settings page</a> to complete setup.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
} else {
|
||||
document.getElementById("success").innerHTML = "⚠️ Failed to update settings.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
}
|
||||
document.getElementById("success").style.display = "none";
|
||||
submitButton.innerHTML = "✅ Successfully updated";
|
||||
setTimeout(function() {
|
||||
submitButton.innerHTML = "Save";
|
||||
submitButton.disabled = false;
|
||||
}, 2000);
|
||||
})
|
||||
.catch(error => {
|
||||
document.getElementById("success").innerHTML = "⚠️ Failed to save Notion content.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
submitButton.innerHTML = "⚠️ Failed to save content";
|
||||
setTimeout(function() {
|
||||
submitButton.innerHTML = "Save";
|
||||
submitButton.disabled = false;
|
||||
}, 2000);
|
||||
});
|
||||
});
|
||||
</script>
|
||||
{% endblock %}
|
||||
|
|
|
@ -189,7 +189,6 @@
|
|||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log(data);
|
||||
document.getElementById("results").innerHTML = render_results(data, query, type);
|
||||
});
|
||||
}
|
||||
|
|
|
@ -56,6 +56,7 @@ locale.setlocale(locale.LC_ALL, "")
|
|||
from khoj.configure import configure_routes, initialize_server, configure_middleware
|
||||
from khoj.utils import state
|
||||
from khoj.utils.cli import cli
|
||||
from khoj.utils.initialization import initialization
|
||||
|
||||
# Setup Logger
|
||||
rich_handler = RichHandler(rich_tracebacks=True)
|
||||
|
@ -74,8 +75,7 @@ def run(should_start_server=True):
|
|||
args = cli(state.cli_args)
|
||||
set_state(args)
|
||||
|
||||
# Create app directory, if it doesn't exist
|
||||
state.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"🚒 Initializing Khoj v{state.khoj_version}")
|
||||
|
||||
# Set Logging Level
|
||||
if args.verbose == 0:
|
||||
|
@ -83,6 +83,11 @@ def run(should_start_server=True):
|
|||
elif args.verbose >= 1:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
initialization()
|
||||
|
||||
# Create app directory, if it doesn't exist
|
||||
state.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Set Log File
|
||||
fh = logging.FileHandler(state.config_file.parent / "khoj.log", encoding="utf-8")
|
||||
fh.setLevel(logging.DEBUG)
|
||||
|
@ -97,7 +102,7 @@ def run(should_start_server=True):
|
|||
configure_routes(app)
|
||||
|
||||
# Mount Django and Static Files
|
||||
app.mount("/django", django_app, name="django")
|
||||
app.mount("/server", django_app, name="server")
|
||||
static_dir = "static"
|
||||
if not os.path.exists(static_dir):
|
||||
os.mkdir(static_dir)
|
||||
|
|
|
@ -55,10 +55,10 @@ def extract_questions_offline(
|
|||
last_year = datetime.now().year - 1
|
||||
last_christmas_date = f"{last_year}-12-25"
|
||||
next_christmas_date = f"{datetime.now().year}-12-25"
|
||||
system_prompt = prompts.extract_questions_system_prompt_llamav2.format(
|
||||
message=(prompts.system_prompt_message_extract_questions_llamav2)
|
||||
system_prompt = prompts.system_prompt_extract_questions_gpt4all.format(
|
||||
message=(prompts.system_prompt_message_extract_questions_gpt4all)
|
||||
)
|
||||
example_questions = prompts.extract_questions_llamav2_sample.format(
|
||||
example_questions = prompts.extract_questions_gpt4all_sample.format(
|
||||
query=text,
|
||||
chat_history=chat_history,
|
||||
current_date=current_date,
|
||||
|
@ -150,14 +150,14 @@ def converse_offline(
|
|||
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message):
|
||||
conversation_primer = user_query
|
||||
else:
|
||||
conversation_primer = prompts.notes_conversation_llamav2.format(
|
||||
conversation_primer = prompts.notes_conversation_gpt4all.format(
|
||||
query=user_query, references=compiled_references_message
|
||||
)
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
conversation_primer,
|
||||
prompts.system_prompt_message_llamav2,
|
||||
prompts.system_prompt_message_gpt4all,
|
||||
conversation_log,
|
||||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
|
@ -183,16 +183,16 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
|
|||
conversation_history = messages[1:-1]
|
||||
|
||||
formatted_messages = [
|
||||
prompts.chat_history_llamav2_from_assistant.format(message=message.content)
|
||||
prompts.khoj_message_gpt4all.format(message=message.content)
|
||||
if message.role == "assistant"
|
||||
else prompts.chat_history_llamav2_from_user.format(message=message.content)
|
||||
else prompts.user_message_gpt4all.format(message=message.content)
|
||||
for message in conversation_history
|
||||
]
|
||||
|
||||
stop_words = ["<s>"]
|
||||
chat_history = "".join(formatted_messages)
|
||||
templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
|
||||
templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
|
||||
templated_system_message = prompts.system_prompt_gpt4all.format(message=system_message.content)
|
||||
templated_user_message = prompts.user_message_gpt4all.format(message=user_message.content)
|
||||
prompted_message = templated_system_message + chat_history + templated_user_message
|
||||
|
||||
state.chat_lock.acquire()
|
||||
|
|
|
@ -20,27 +20,6 @@ from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def summarize(session, model, api_key=None, temperature=0.5, max_tokens=200):
|
||||
"""
|
||||
Summarize conversation session using the specified OpenAI chat model
|
||||
"""
|
||||
messages = [ChatMessage(content=prompts.summarize_chat.format(), role="system")] + session
|
||||
|
||||
# Get Response from GPT
|
||||
logger.debug(f"Prompt for GPT: {messages}")
|
||||
response = completion_with_backoff(
|
||||
messages=messages,
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
model_kwargs={"stop": ['"""'], "frequency_penalty": 0.2},
|
||||
openai_api_key=api_key,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
return str(response.content).replace("\n\n", "")
|
||||
|
||||
|
||||
def extract_questions(
|
||||
text,
|
||||
model: Optional[str] = "gpt-4",
|
||||
|
@ -131,16 +110,14 @@ def converse(
|
|||
completion_func(chat_response=prompts.no_notes_found.format())
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references):
|
||||
conversation_primer = prompts.general_conversation.format(current_date=current_date, query=user_query)
|
||||
conversation_primer = prompts.general_conversation.format(query=user_query)
|
||||
else:
|
||||
conversation_primer = prompts.notes_conversation.format(
|
||||
current_date=current_date, query=user_query, references=compiled_references
|
||||
)
|
||||
conversation_primer = prompts.notes_conversation.format(query=user_query, references=compiled_references)
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
conversation_primer,
|
||||
prompts.personality.format(),
|
||||
prompts.personality.format(current_date=current_date),
|
||||
conversation_log,
|
||||
model,
|
||||
max_prompt_size,
|
||||
|
@ -157,4 +134,5 @@ def converse(
|
|||
temperature=temperature,
|
||||
openai_api_key=api_key,
|
||||
completion_func=completion_func,
|
||||
model_kwargs={"stop": ["Notes:\n["]},
|
||||
)
|
||||
|
|
|
@ -69,15 +69,15 @@ def completion_with_backoff(**kwargs):
|
|||
reraise=True,
|
||||
)
|
||||
def chat_completion_with_backoff(
|
||||
messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
|
||||
messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None, model_kwargs=None
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, completion_func=completion_func)
|
||||
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
|
||||
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs))
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
|
||||
def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_kwargs=None):
|
||||
callback_handler = StreamingChatCallbackHandler(g)
|
||||
chat = ChatOpenAI(
|
||||
streaming=True,
|
||||
|
@ -86,6 +86,7 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
|
|||
model_name=model_name, # type: ignore
|
||||
temperature=temperature,
|
||||
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
||||
model_kwargs=model_kwargs,
|
||||
request_timeout=20,
|
||||
max_retries=1,
|
||||
client=None,
|
||||
|
|
|
@ -4,30 +4,44 @@ from langchain.prompts import PromptTemplate
|
|||
|
||||
## Personality
|
||||
## --
|
||||
personality = PromptTemplate.from_template("You are Khoj, a smart, inquisitive and helpful personal assistant.")
|
||||
personality = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, a smart, inquisitive and helpful personal assistant.
|
||||
Use your general knowledge and the past conversation with the user as context to inform your responses.
|
||||
You were created by Khoj Inc. with the following capabilities:
|
||||
|
||||
- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you.
|
||||
- You cannot set reminders.
|
||||
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
|
||||
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
|
||||
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
|
||||
|
||||
Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev.
|
||||
Today is {current_date} in UTC.
|
||||
""".strip()
|
||||
)
|
||||
|
||||
## General Conversation
|
||||
## --
|
||||
general_conversation = PromptTemplate.from_template(
|
||||
"""
|
||||
Using your general knowledge and our past conversations as context, answer the following question.
|
||||
Current Date: {current_date}
|
||||
|
||||
Question: {query}
|
||||
{query}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
no_notes_found = PromptTemplate.from_template(
|
||||
"""
|
||||
I'm sorry, I couldn't find any relevant notes to respond to your message.
|
||||
""".strip()
|
||||
)
|
||||
|
||||
system_prompt_message_llamav2 = f"""You are Khoj, a smart, inquisitive and helpful personal assistant.
|
||||
## Conversation Prompts for GPT4All Models
|
||||
## --
|
||||
system_prompt_message_gpt4all = f"""You are Khoj, a smart, inquisitive and helpful personal assistant.
|
||||
Using your general knowledge and our past conversations as context, answer the following question.
|
||||
If you do not know the answer, say 'I don't know.'"""
|
||||
|
||||
system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and intelligent personal assistant. When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
|
||||
system_prompt_message_extract_questions_gpt4all = f"""You are Khoj, a kind and intelligent personal assistant. When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
|
||||
- Write the question as if you can search for the answer on the user's personal notes.
|
||||
- Try to be as specific as possible. Instead of saying "they" or "it" or "he", use the name of the person or thing you are referring to. For example, instead of saying "Which store did they go to?", say "Which store did Alice and Bob go to?".
|
||||
- Add as much context from the previous questions and notes as required into your search queries.
|
||||
|
@ -35,61 +49,47 @@ system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and i
|
|||
What follow-up questions, if any, will you need to ask to answer the user's question?
|
||||
"""
|
||||
|
||||
system_prompt_llamav2 = PromptTemplate.from_template(
|
||||
system_prompt_gpt4all = PromptTemplate.from_template(
|
||||
"""
|
||||
<s>[INST] <<SYS>>
|
||||
{message}
|
||||
<</SYS>>Hi there! [/INST] Hello! How can I help you today? </s>"""
|
||||
)
|
||||
|
||||
extract_questions_system_prompt_llamav2 = PromptTemplate.from_template(
|
||||
system_prompt_extract_questions_gpt4all = PromptTemplate.from_template(
|
||||
"""
|
||||
<s>[INST] <<SYS>>
|
||||
{message}
|
||||
<</SYS>>[/INST]</s>"""
|
||||
)
|
||||
|
||||
general_conversation_llamav2 = PromptTemplate.from_template(
|
||||
"""
|
||||
<s>[INST] {query} [/INST]
|
||||
""".strip()
|
||||
)
|
||||
|
||||
chat_history_llamav2_from_user = PromptTemplate.from_template(
|
||||
user_message_gpt4all = PromptTemplate.from_template(
|
||||
"""
|
||||
<s>[INST] {message} [/INST]
|
||||
""".strip()
|
||||
)
|
||||
|
||||
chat_history_llamav2_from_assistant = PromptTemplate.from_template(
|
||||
khoj_message_gpt4all = PromptTemplate.from_template(
|
||||
"""
|
||||
{message}</s>
|
||||
""".strip()
|
||||
)
|
||||
|
||||
conversation_llamav2 = PromptTemplate.from_template(
|
||||
"""
|
||||
<s>[INST] {query} [/INST]
|
||||
""".strip()
|
||||
)
|
||||
|
||||
## Notes Conversation
|
||||
## --
|
||||
notes_conversation = PromptTemplate.from_template(
|
||||
"""
|
||||
Using my personal notes and our past conversations as context, answer the following question.
|
||||
Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
|
||||
These questions should end with a question mark.
|
||||
Current Date: {current_date}
|
||||
Use my personal notes and our past conversations to inform your response.
|
||||
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided notes or past conversations.
|
||||
|
||||
Notes:
|
||||
{references}
|
||||
|
||||
Question: {query}
|
||||
Query: {query}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
notes_conversation_llamav2 = PromptTemplate.from_template(
|
||||
notes_conversation_gpt4all = PromptTemplate.from_template(
|
||||
"""
|
||||
User's Notes:
|
||||
{references}
|
||||
|
@ -98,13 +98,6 @@ Question: {query}
|
|||
)
|
||||
|
||||
|
||||
## Summarize Chat
|
||||
## --
|
||||
summarize_chat = PromptTemplate.from_template(
|
||||
f"{personality.format()} Summarize the conversation from your first person perspective"
|
||||
)
|
||||
|
||||
|
||||
## Summarize Notes
|
||||
## --
|
||||
summarize_notes = PromptTemplate.from_template(
|
||||
|
@ -132,7 +125,10 @@ Question: {user_query}
|
|||
Answer (in second person):"""
|
||||
)
|
||||
|
||||
extract_questions_llamav2_sample = PromptTemplate.from_template(
|
||||
|
||||
## Extract Questions
|
||||
## --
|
||||
extract_questions_gpt4all_sample = PromptTemplate.from_template(
|
||||
"""
|
||||
<s>[INST] <<SYS>>Current Date: {current_date}<</SYS>> [/INST]</s>
|
||||
<s>[INST] How was my trip to Cambodia? [/INST]
|
||||
|
@ -157,8 +153,6 @@ Use these notes from the user's previous conversations to provide a response:
|
|||
)
|
||||
|
||||
|
||||
## Extract Questions
|
||||
## --
|
||||
extract_questions = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the user's notes.
|
||||
|
|
|
@ -27,5 +27,5 @@ class CrossEncoderModel:
|
|||
|
||||
def predict(self, query, hits: List[SearchResponse]):
|
||||
cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
|
||||
cross_scores = self.cross_encoder_model.predict(cross__inp)
|
||||
cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True)
|
||||
return cross_scores
|
||||
|
|
|
@ -7,7 +7,7 @@ import json
|
|||
from typing import List, Optional, Union, Any
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter, HTTPException, Header, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
||||
from starlette.authentication import requires
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
|
@ -36,6 +36,7 @@ from khoj.routers.helpers import (
|
|||
agenerate_chat_response,
|
||||
update_telemetry_state,
|
||||
is_ready_to_chat,
|
||||
ApiUserRateLimiter,
|
||||
)
|
||||
from khoj.processor.conversation.prompts import help_message
|
||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||
|
@ -177,11 +178,15 @@ async def set_content_config_github_data(
|
|||
|
||||
user = request.user.object
|
||||
|
||||
await adapters.set_user_github_config(
|
||||
user=user,
|
||||
pat_token=updated_config.pat_token,
|
||||
repos=updated_config.repos,
|
||||
)
|
||||
try:
|
||||
await adapters.set_user_github_config(
|
||||
user=user,
|
||||
pat_token=updated_config.pat_token,
|
||||
repos=updated_config.repos,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to set Github config")
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
@ -205,10 +210,14 @@ async def set_content_config_notion_data(
|
|||
|
||||
user = request.user.object
|
||||
|
||||
await adapters.set_notion_config(
|
||||
user=user,
|
||||
token=updated_config.token,
|
||||
)
|
||||
try:
|
||||
await adapters.set_notion_config(
|
||||
user=user,
|
||||
token=updated_config.token,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to set Github config")
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
@ -348,7 +357,7 @@ async def search(
|
|||
n: Optional[int] = 5,
|
||||
t: Optional[SearchType] = SearchType.All,
|
||||
r: Optional[bool] = False,
|
||||
score_threshold: Optional[Union[float, None]] = None,
|
||||
max_distance: Optional[Union[float, None]] = None,
|
||||
dedupe: Optional[bool] = True,
|
||||
client: Optional[str] = None,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
|
@ -367,12 +376,12 @@ async def search(
|
|||
# initialize variables
|
||||
user_query = q.strip()
|
||||
results_count = n or 5
|
||||
score_threshold = score_threshold if score_threshold is not None else -math.inf
|
||||
max_distance = max_distance if max_distance is not None else math.inf
|
||||
search_futures: List[concurrent.futures.Future] = []
|
||||
|
||||
# return cached results, if available
|
||||
if user:
|
||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
|
||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
|
||||
if query_cache_key in state.query_cache[user.uuid]:
|
||||
logger.debug(f"Return response from query cache")
|
||||
return state.query_cache[user.uuid][query_cache_key]
|
||||
|
@ -409,8 +418,7 @@ async def search(
|
|||
user_query,
|
||||
t,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
rank_results=r or False,
|
||||
score_threshold=score_threshold,
|
||||
max_distance=max_distance,
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -423,7 +431,6 @@ async def search(
|
|||
results_count,
|
||||
state.search_models.image_search,
|
||||
state.content_index.image,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -446,11 +453,10 @@ async def search(
|
|||
# Collate results
|
||||
results += text_search.collate_results(hits, dedupe=dedupe)
|
||||
|
||||
if r:
|
||||
results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
|
||||
else:
|
||||
# Sort results across all content types and take top results
|
||||
results = sorted(results, key=lambda x: float(x.score))[:results_count]
|
||||
results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[
|
||||
:results_count
|
||||
]
|
||||
|
||||
# Cache results
|
||||
if user:
|
||||
|
@ -575,11 +581,14 @@ async def chat(
|
|||
request: Request,
|
||||
q: str,
|
||||
n: Optional[int] = 5,
|
||||
d: Optional[float] = 0.15,
|
||||
client: Optional[str] = None,
|
||||
stream: Optional[bool] = False,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
|
||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
|
||||
) -> Response:
|
||||
user = request.user.object
|
||||
|
||||
|
@ -591,7 +600,7 @@ async def chat(
|
|||
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
|
||||
|
||||
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||
request, meta_log, q, (n or 5), conversation_command
|
||||
request, meta_log, q, (n or 5), (d or math.inf), conversation_command
|
||||
)
|
||||
|
||||
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
|
||||
|
@ -606,7 +615,7 @@ async def chat(
|
|||
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
|
||||
|
||||
# Get the (streamed) chat response from the LLM of choice.
|
||||
llm_response = await agenerate_chat_response(
|
||||
llm_response, chat_metadata = await agenerate_chat_response(
|
||||
defiltered_query,
|
||||
meta_log,
|
||||
compiled_references,
|
||||
|
@ -615,6 +624,19 @@ async def chat(
|
|||
user,
|
||||
)
|
||||
|
||||
chat_metadata.update({"conversation_command": conversation_command.value})
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
client=client,
|
||||
user_agent=user_agent,
|
||||
referer=referer,
|
||||
host=host,
|
||||
metadata=chat_metadata,
|
||||
)
|
||||
|
||||
if llm_response is None:
|
||||
return Response(content=llm_response, media_type="text/plain", status_code=500)
|
||||
|
||||
|
@ -634,16 +656,6 @@ async def chat(
|
|||
|
||||
response_obj = {"response": actual_response, "context": compiled_references}
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
client=client,
|
||||
user_agent=user_agent,
|
||||
referer=referer,
|
||||
host=host,
|
||||
)
|
||||
|
||||
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
|
||||
|
||||
|
||||
|
@ -652,6 +664,7 @@ async def extract_references_and_questions(
|
|||
meta_log: dict,
|
||||
q: str,
|
||||
n: int,
|
||||
d: float,
|
||||
conversation_type: ConversationCommand = ConversationCommand.Default,
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
|
@ -663,7 +676,7 @@ async def extract_references_and_questions(
|
|||
if conversation_type == ConversationCommand.General:
|
||||
return compiled_references, inferred_queries, q
|
||||
|
||||
if not sync_to_async(EntryAdapters.user_has_entries)(user=user):
|
||||
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
|
||||
logger.warning(
|
||||
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
|
||||
)
|
||||
|
@ -712,7 +725,7 @@ async def extract_references_and_questions(
|
|||
request=request,
|
||||
n=n_items,
|
||||
r=True,
|
||||
score_threshold=-5.0,
|
||||
max_distance=d,
|
||||
dedupe=False,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -16,7 +16,7 @@ from google.auth.transport import requests as google_requests
|
|||
|
||||
# Internal Packages
|
||||
from database.adapters import get_khoj_tokens, get_or_create_user, create_khoj_token, delete_khoj_token
|
||||
from database.models import KhojApiUser
|
||||
from khoj.routers.helpers import update_telemetry_state
|
||||
from khoj.utils import state
|
||||
|
||||
|
||||
|
@ -100,6 +100,16 @@ async def auth(request: Request):
|
|||
if khoj_user:
|
||||
request.session["user"] = dict(idinfo)
|
||||
|
||||
if not khoj_user.last_login:
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="create_user",
|
||||
metadata={"user_id": str(khoj_user.uuid)},
|
||||
)
|
||||
logger.log(logging.INFO, f"New User Created: {khoj_user.uuid}")
|
||||
RedirectResponse(url="/?status=welcome")
|
||||
|
||||
return RedirectResponse(url="/")
|
||||
|
||||
|
||||
|
|
|
@ -1,21 +1,27 @@
|
|||
import logging
|
||||
# Standard Packages
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Iterator, List, Optional, Union
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import logging
|
||||
from time import time
|
||||
from typing import Iterator, List, Optional, Union, Tuple, Dict
|
||||
|
||||
# External Packages
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils import state
|
||||
from khoj.utils.config import GPT4AllProcessorModel
|
||||
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
||||
from khoj.processor.conversation.openai.gpt import converse
|
||||
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
|
||||
from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
|
||||
from database.models import KhojUser
|
||||
from database.models import KhojUser, Subscription
|
||||
from database.adapters import ConversationAdapters
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
@ -61,12 +67,15 @@ def update_telemetry_state(
|
|||
metadata: Optional[dict] = None,
|
||||
):
|
||||
user: KhojUser = request.user.object if request.user.is_authenticated else None
|
||||
subscription: Subscription = user.subscription if user and user.subscription else None
|
||||
user_state = {
|
||||
"client_host": request.client.host if request.client else None,
|
||||
"user_agent": user_agent or "unknown",
|
||||
"referer": referer or "unknown",
|
||||
"host": host or "unknown",
|
||||
"server_id": str(user.uuid) if user else None,
|
||||
"subscription_type": subscription.type if subscription else None,
|
||||
"is_recurring": subscription.is_recurring if subscription else None,
|
||||
}
|
||||
|
||||
if metadata:
|
||||
|
@ -109,7 +118,7 @@ def generate_chat_response(
|
|||
inferred_queries: List[str] = [],
|
||||
conversation_command: ConversationCommand = ConversationCommand.Default,
|
||||
user: KhojUser = None,
|
||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||
def _save_to_conversation_log(
|
||||
q: str,
|
||||
chat_response: str,
|
||||
|
@ -132,6 +141,8 @@ def generate_chat_response(
|
|||
chat_response = None
|
||||
logger.debug(f"Conversation Type: {conversation_command.name}")
|
||||
|
||||
metadata = {}
|
||||
|
||||
try:
|
||||
partial_completion = partial(
|
||||
_save_to_conversation_log,
|
||||
|
@ -148,8 +159,8 @@ def generate_chat_response(
|
|||
conversation_config = ConversationAdapters.get_default_conversation_config()
|
||||
openai_chat_config = ConversationAdapters.get_openai_conversation_config()
|
||||
if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
|
||||
if state.gpt4all_processor_config.loaded_model is None:
|
||||
state.gpt4all_processor_config = GPT4AllProcessorModel(offline_chat_config.chat_model)
|
||||
if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
|
||||
state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
|
||||
|
||||
loaded_model = state.gpt4all_processor_config.loaded_model
|
||||
chat_response = converse_offline(
|
||||
|
@ -179,8 +190,33 @@ def generate_chat_response(
|
|||
tokenizer_name=conversation_config.tokenizer,
|
||||
)
|
||||
|
||||
metadata.update({"chat_model": conversation_config.chat_model})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
return chat_response
|
||||
return chat_response, metadata
|
||||
|
||||
|
||||
class ApiUserRateLimiter:
|
||||
def __init__(self, requests: int, window: int):
|
||||
self.requests = requests
|
||||
self.window = window
|
||||
self.cache: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
def __call__(self, request: Request):
|
||||
user: KhojUser = request.user.object
|
||||
user_requests = self.cache[user.uuid]
|
||||
|
||||
# Remove requests outside of the time window
|
||||
cutoff = time() - self.window
|
||||
while user_requests and user_requests[0] < cutoff:
|
||||
user_requests.pop(0)
|
||||
|
||||
# Check if the user has exceeded the rate limit
|
||||
if len(user_requests) >= self.requests:
|
||||
raise HTTPException(status_code=429, detail="Too Many Requests")
|
||||
|
||||
# Add the current request to the cache
|
||||
user_requests.append(time())
|
||||
|
|
|
@ -146,7 +146,7 @@ def extract_metadata(image_name):
|
|||
|
||||
|
||||
async def query(
|
||||
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
|
||||
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf
|
||||
):
|
||||
# Set query to image content if query is of form file:/path/to/file.png
|
||||
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
|
||||
|
@ -167,7 +167,8 @@ async def query(
|
|||
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
|
||||
with timer("Search Time", logger):
|
||||
image_hits = {
|
||||
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
|
||||
# Map scores to distance metric by multiplying by -1
|
||||
result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]}
|
||||
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
|
||||
}
|
||||
|
||||
|
@ -204,7 +205,7 @@ async def query(
|
|||
]
|
||||
|
||||
# Filter results by score threshold
|
||||
hits = [hit for hit in hits if hit["image_score"] >= score_threshold]
|
||||
hits = [hit for hit in hits if hit["image_score"] <= score_threshold]
|
||||
|
||||
# Sort the images based on their combined metadata, image scores
|
||||
return sorted(hits, key=lambda hit: hit["score"], reverse=True)
|
||||
|
|
|
@ -104,8 +104,7 @@ async def query(
|
|||
raw_query: str,
|
||||
type: SearchType = SearchType.All,
|
||||
question_embedding: Union[torch.Tensor, None] = None,
|
||||
rank_results: bool = False,
|
||||
score_threshold: float = -math.inf,
|
||||
max_distance: float = math.inf,
|
||||
) -> Tuple[List[dict], List[Entry]]:
|
||||
"Search for entries that answer the query"
|
||||
|
||||
|
@ -127,6 +126,7 @@ async def query(
|
|||
max_results=top_k,
|
||||
file_type_filter=file_type,
|
||||
raw_query=raw_query,
|
||||
max_distance=max_distance,
|
||||
).all()
|
||||
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
|
||||
|
||||
|
@ -177,12 +177,16 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
|
|||
)
|
||||
|
||||
|
||||
def rerank_and_sort_results(hits, query):
|
||||
def rerank_and_sort_results(hits, query, rank_results):
|
||||
# If we have more than one result and reranking is enabled
|
||||
rank_results = rank_results and len(list(hits)) > 1
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
hits = cross_encoder_score(query, hits)
|
||||
if rank_results:
|
||||
hits = cross_encoder_score(query, hits)
|
||||
|
||||
# Sort results by cross-encoder score followed by bi-encoder score
|
||||
hits = sort_results(rank_results=True, hits=hits)
|
||||
hits = sort_results(rank_results=rank_results, hits=hits)
|
||||
|
||||
return hits
|
||||
|
||||
|
@ -217,9 +221,9 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
|
|||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||
cross_scores = state.cross_encoder_model.predict(query, hits)
|
||||
|
||||
# Store cross-encoder scores in results dictionary for ranking
|
||||
# Convert cross-encoder scores to distances and pass in hits for reranking
|
||||
for idx in range(len(cross_scores)):
|
||||
hits[idx]["cross_score"] = cross_scores[idx]
|
||||
hits[idx]["cross_score"] = 1 - cross_scores[idx]
|
||||
|
||||
return hits
|
||||
|
||||
|
@ -227,7 +231,7 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
|
|||
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
|
||||
"""Order results by cross-encoder score followed by bi-encoder score"""
|
||||
with timer("Rank Time", logger, state.device):
|
||||
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
|
||||
hits.sort(key=lambda x: x["score"]) # sort by bi-encoder score
|
||||
if rank_results:
|
||||
hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score
|
||||
hits.sort(key=lambda x: x["cross_score"]) # sort by cross-encoder score
|
||||
return hits
|
||||
|
|
|
@ -6,6 +6,7 @@ empty_escape_sequences = "\n|\r|\t| "
|
|||
app_env_filepath = "~/.khoj/env"
|
||||
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
|
||||
content_directory = "~/.khoj/content/"
|
||||
default_offline_chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf"
|
||||
|
||||
empty_config = {
|
||||
"search-type": {
|
||||
|
|
98
src/khoj/utils/initialization.py
Normal file
98
src/khoj/utils/initialization.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
from database.models import (
|
||||
KhojUser,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
OpenAIProcessorConversationConfig,
|
||||
ChatModelOptions,
|
||||
)
|
||||
|
||||
from khoj.utils.constants import default_offline_chat_model
|
||||
|
||||
from database.adapters import ConversationAdapters
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialization():
|
||||
def _create_admin_user():
|
||||
logger.info(
|
||||
"👩✈️ Setting up admin user. These credentials will allow you to configure your server at /server/admin."
|
||||
)
|
||||
email_addr = os.getenv("KHOJ_ADMIN_EMAIL") or input("Email: ")
|
||||
password = os.getenv("KHOJ_ADMIN_PASSWORD") or input("Password: ")
|
||||
admin_user = KhojUser.objects.create_superuser(email=email_addr, username=email_addr, password=password)
|
||||
logger.info(f"👩✈️ Created admin user: {admin_user.email}")
|
||||
|
||||
def _create_chat_configuration():
|
||||
logger.info(
|
||||
"🗣️ Configure chat models available to your server. You can always update these at /server/admin using the credentials of your admin account"
|
||||
)
|
||||
try:
|
||||
# Some environments don't support interactive input. We catch the exception and return if that's the case. The admin can still configure their settings from the admin page.
|
||||
input()
|
||||
except EOFError:
|
||||
return
|
||||
|
||||
try:
|
||||
# Note: gpt4all package is not available on all devices.
|
||||
# So ensure gpt4all package is installed before continuing this step.
|
||||
import gpt4all
|
||||
|
||||
use_offline_model = input("Use offline chat model? (y/n): ")
|
||||
if use_offline_model == "y":
|
||||
logger.info("🗣️ Setting up offline chat model")
|
||||
OfflineChatProcessorConversationConfig.objects.create(enabled=True)
|
||||
|
||||
offline_chat_model = input(
|
||||
f"Enter the name of the offline chat model you want to use, based on the models in HuggingFace (press enter to use the default: {default_offline_chat_model}): "
|
||||
)
|
||||
if offline_chat_model == "":
|
||||
ChatModelOptions.objects.create(
|
||||
chat_model=default_offline_chat_model, model_type=ChatModelOptions.ModelType.OFFLINE
|
||||
)
|
||||
else:
|
||||
max_tokens = input("Enter the maximum number of tokens to use for the offline chat model:")
|
||||
tokenizer = input("Enter the tokenizer to use for the offline chat model:")
|
||||
ChatModelOptions.objects.create(
|
||||
chat_model=offline_chat_model,
|
||||
model_type=ChatModelOptions.ModelType.OFFLINE,
|
||||
max_prompt_size=max_tokens,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.warning("Offline models are not supported on this device.")
|
||||
|
||||
use_openai_model = input("Use OpenAI chat model? (y/n): ")
|
||||
|
||||
if use_openai_model == "y":
|
||||
logger.info("🗣️ Setting up OpenAI chat model")
|
||||
api_key = input("Enter your OpenAI API key: ")
|
||||
OpenAIProcessorConversationConfig.objects.create(api_key=api_key)
|
||||
openai_chat_model = input("Enter the name of the OpenAI chat model you want to use: ")
|
||||
max_tokens = input("Enter the maximum number of tokens to use for the OpenAI chat model:")
|
||||
ChatModelOptions.objects.create(
|
||||
chat_model=openai_chat_model, model_type=ChatModelOptions.ModelType.OPENAI, max_tokens=max_tokens
|
||||
)
|
||||
|
||||
logger.info("🗣️ Chat model configuration complete")
|
||||
|
||||
admin_user = KhojUser.objects.filter(is_staff=True).first()
|
||||
if admin_user is None:
|
||||
while True:
|
||||
try:
|
||||
_create_admin_user()
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to create admin user: {e}", exc_info=True)
|
||||
|
||||
chat_config = ConversationAdapters.get_default_conversation_config()
|
||||
if admin_user is None and chat_config is None:
|
||||
while True:
|
||||
try:
|
||||
_create_chat_configuration()
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to create chat configuration: {e}", exc_info=True)
|
|
@ -43,6 +43,7 @@ from tests.helpers import (
|
|||
OpenAIProcessorConversationConfigFactory,
|
||||
OfflineChatProcessorConversationConfigFactory,
|
||||
UserConversationProcessorConfigFactory,
|
||||
SubscriptionFactory,
|
||||
)
|
||||
|
||||
|
||||
|
@ -69,7 +70,9 @@ def search_config() -> SearchConfig:
|
|||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def default_user():
|
||||
return UserFactory()
|
||||
user = UserFactory()
|
||||
SubscriptionFactory(user=user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
|
@ -78,11 +81,31 @@ def default_user2():
|
|||
if KhojUser.objects.filter(username="default").exists():
|
||||
return KhojUser.objects.get(username="default")
|
||||
|
||||
return KhojUser.objects.create(
|
||||
user = KhojUser.objects.create(
|
||||
username="default",
|
||||
email="default@example.com",
|
||||
password="default",
|
||||
)
|
||||
SubscriptionFactory(user=user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def default_user3():
|
||||
"""
|
||||
This user should not have any data associated with it
|
||||
"""
|
||||
if KhojUser.objects.filter(username="default3").exists():
|
||||
return KhojUser.objects.get(username="default3")
|
||||
|
||||
user = KhojUser.objects.create(
|
||||
username="default3",
|
||||
email="default3@example.com",
|
||||
password="default3",
|
||||
)
|
||||
SubscriptionFactory(user=user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
|
@ -111,6 +134,19 @@ def api_user2(default_user2):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def api_user3(default_user3):
|
||||
if KhojApiUser.objects.filter(user=default_user3).exists():
|
||||
return KhojApiUser.objects.get(user=default_user3)
|
||||
|
||||
return KhojApiUser.objects.create(
|
||||
user=default_user3,
|
||||
name="api-key",
|
||||
token="kk-diff-secret-3",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def search_models(search_config: SearchConfig):
|
||||
search_models = SearchModels()
|
||||
|
@ -206,7 +242,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser):
|
|||
OpenAIProcessorConversationConfigFactory()
|
||||
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
|
||||
|
||||
state.anonymous_mode = False
|
||||
state.anonymous_mode = True
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
@ -224,7 +260,9 @@ def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUs
|
|||
|
||||
# Initialize Processor from Config
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
|
||||
OpenAIProcessorConversationConfigFactory()
|
||||
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
|
||||
|
||||
state.anonymous_mode = True
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from database.models import (
|
|||
OpenAIProcessorConversationConfig,
|
||||
UserConversationConfig,
|
||||
Conversation,
|
||||
Subscription,
|
||||
)
|
||||
|
||||
|
||||
|
@ -68,3 +69,13 @@ class ConversationFactory(factory.django.DjangoModelFactory):
|
|||
model = Conversation
|
||||
|
||||
user = factory.SubFactory(UserFactory)
|
||||
|
||||
|
||||
class SubscriptionFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = Subscription
|
||||
|
||||
user = factory.SubFactory(UserFactory)
|
||||
type = "standard"
|
||||
is_recurring = False
|
||||
renewal_date = "2100-04-01"
|
||||
|
|
|
@ -16,7 +16,7 @@ from khoj.utils.state import search_models, content_index, config
|
|||
from khoj.search_type import text_search, image_search
|
||||
from khoj.utils.rawconfig import ContentConfig, SearchConfig
|
||||
from khoj.processor.org_mode.org_to_entries import OrgToEntries
|
||||
from database.models import KhojUser
|
||||
from database.models import KhojUser, KhojApiUser
|
||||
from database.adapters import EntryAdapters
|
||||
|
||||
|
||||
|
@ -351,6 +351,24 @@ def test_different_user_data_not_accessed(client, sample_org_data, default_user:
|
|||
assert len(response.json()) == 1 and response.json()["detail"] == "Forbidden"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiUser):
|
||||
# Arrange
|
||||
token = api_user3.token
|
||||
headers = {"Authorization": "Bearer " + token}
|
||||
user_query = quote("How to git install application?")
|
||||
|
||||
# Act
|
||||
response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
# assert actual response has no data as the default_user3, though other users have data
|
||||
assert len(response.json()) == 0
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
def get_sample_files_data():
|
||||
return [
|
||||
("files", ("path/to/filename.org", "* practicing piano", "text/org")),
|
||||
|
|
|
@ -307,6 +307,8 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_
|
|||
"which one is",
|
||||
"which of namita's sons",
|
||||
"the birth order",
|
||||
"provide more context",
|
||||
"provide me with more context",
|
||||
]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
|
|
Loading…
Reference in a new issue