Resolve merge conflicts in auth.py with remove KhojApiUser import

This commit is contained in:
sabaimran 2023-11-15 17:32:53 -08:00
commit 6b17aeb32d
35 changed files with 605 additions and 266 deletions

View file

@ -10,7 +10,15 @@ services:
POSTGRES_DB: postgres POSTGRES_DB: postgres
volumes: volumes:
- khoj_db:/var/lib/postgresql/data/ - khoj_db:/var/lib/postgresql/data/
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 30s
timeout: 10s
retries: 5
server: server:
depends_on:
database:
condition: service_healthy
# Use the following line to use the latest version of khoj. Otherwise, it will build from source. # Use the following line to use the latest version of khoj. Otherwise, it will build from source.
image: ghcr.io/khoj-ai/khoj:latest image: ghcr.io/khoj-ai/khoj:latest
# Uncomment the following line to build from source. This will take a few minutes. Comment the next two lines out if you want to use the offiicial image. # Uncomment the following line to build from source. This will take a few minutes. Comment the next two lines out if you want to use the offiicial image.
@ -24,20 +32,6 @@ services:
- "42110:42110" - "42110:42110"
working_dir: /app working_dir: /app
volumes: volumes:
- .:/app
# These mounted volumes hold the raw data that should be indexed for search.
# The path in your local directory (left hand side)
# points to the files you want to index.
# The path of the mounted directory (right hand side),
# must match the path prefix in your config file.
- ./tests/data/org/:/data/org/
- ./tests/data/images/:/data/images/
- ./tests/data/markdown/:/data/markdown/
- ./tests/data/pdf/:/data/pdf/
# Embeddings and models are populated after the first run
# You can set these volumes to point to empty directories on host
- ./tests/data/embeddings/:/root/.khoj/content/
- ./tests/data/models/:/root/.khoj/search/
- khoj_config:/root/.khoj/ - khoj_config:/root/.khoj/
- khoj_models:/root/.cache/torch/sentence_transformers - khoj_models:/root/.cache/torch/sentence_transformers
# Use 0.0.0.0 to explicitly set the host ip for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/ # Use 0.0.0.0 to explicitly set the host ip for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/
@ -47,9 +41,11 @@ services:
- POSTGRES_PASSWORD=postgres - POSTGRES_PASSWORD=postgres
- POSTGRES_HOST=database - POSTGRES_HOST=database
- POSTGRES_PORT=5432 - POSTGRES_PORT=5432
- GOOGLE_CLIENT_SECRET=bar - KHOJ_DJANGO_SECRET_KEY=secret
- GOOGLE_CLIENT_ID=foo - KHOJ_DEBUG=True
command: --host="0.0.0.0" --port=42110 -vv - KHOJ_ADMIN_EMAIL=username@example.com
- KHOJ_ADMIN_PASSWORD=password
command: --host="0.0.0.0" --port=42110 -vv --anonymous-mode
volumes: volumes:

View file

@ -1,3 +1,4 @@
import math
from typing import Optional, Type, TypeVar, List from typing import Optional, Type, TypeVar, List
from datetime import date, datetime, timedelta from datetime import date, datetime, timedelta
import secrets import secrets
@ -99,6 +100,8 @@ async def create_google_user(token: dict) -> KhojUser:
user=user, user=user,
) )
await Subscription.objects.acreate(user=user, type="trial")
return user return user
@ -433,12 +436,19 @@ class EntryAdapters:
@staticmethod @staticmethod
def search_with_embeddings( def search_with_embeddings(
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None user: KhojUser,
embeddings: Tensor,
max_results: int = 10,
file_type_filter: str = None,
raw_query: str = None,
max_distance: float = math.inf,
): ):
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter) relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
relevant_entries = relevant_entries.filter(user=user).annotate( relevant_entries = relevant_entries.filter(user=user).annotate(
distance=CosineDistance("embeddings", embeddings) distance=CosineDistance("embeddings", embeddings)
) )
relevant_entries = relevant_entries.filter(distance__lte=max_distance)
if file_type_filter: if file_type_filter:
relevant_entries = relevant_entries.filter(file_type=file_type_filter) relevant_entries = relevant_entries.filter(file_type=file_type_filter)
relevant_entries = relevant_entries.order_by("distance") relevant_entries = relevant_entries.order_by("distance")

View file

@ -8,6 +8,7 @@ from database.models import (
ChatModelOptions, ChatModelOptions,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
Subscription,
) )
admin.site.register(KhojUser, UserAdmin) admin.site.register(KhojUser, UserAdmin)
@ -15,3 +16,4 @@ admin.site.register(KhojUser, UserAdmin)
admin.site.register(ChatModelOptions) admin.site.register(ChatModelOptions)
admin.site.register(OpenAIProcessorConversationConfig) admin.site.register(OpenAIProcessorConversationConfig)
admin.site.register(OfflineChatProcessorConversationConfig) admin.site.register(OfflineChatProcessorConversationConfig)
admin.site.register(Subscription)

View file

@ -0,0 +1,21 @@
# Generated by Django 4.2.5 on 2023-11-11 05:39
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
("database", "0014_alter_googleuser_picture"),
]
operations = [
migrations.AlterField(
model_name="subscription",
name="user",
field=models.OneToOneField(
on_delete=django.db.models.deletion.CASCADE, related_name="subscription", to=settings.AUTH_USER_MODEL
),
),
]

View file

@ -0,0 +1,17 @@
# Generated by Django 4.2.5 on 2023-11-11 06:15
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0015_alter_subscription_user"),
]
operations = [
migrations.AlterField(
model_name="subscription",
name="renewal_date",
field=models.DateTimeField(blank=True, default=None, null=True),
),
]

View file

@ -51,10 +51,10 @@ class Subscription(BaseModel):
TRIAL = "trial" TRIAL = "trial"
STANDARD = "standard" STANDARD = "standard"
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) user = models.OneToOneField(KhojUser, on_delete=models.CASCADE, related_name="subscription")
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL) type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
is_recurring = models.BooleanField(default=False) is_recurring = models.BooleanField(default=False)
renewal_date = models.DateTimeField(null=True, default=None) renewal_date = models.DateTimeField(null=True, default=None, blank=True)
class NotionConfig(BaseModel): class NotionConfig(BaseModel):

View file

@ -577,12 +577,12 @@
cursor: pointer; cursor: pointer;
transition: background 0.2s ease-in-out; transition: background 0.2s ease-in-out;
text-align: left; text-align: left;
max-height: 50px; max-height: 75px;
transition: max-height 0.3s ease-in-out; transition: max-height 0.3s ease-in-out;
overflow: hidden; overflow: hidden;
} }
button.reference-button.expanded { button.reference-button.expanded {
max-height: 200px; max-height: none;
} }
button.reference-button::before { button.reference-button::before {

View file

@ -198,12 +198,6 @@ khojKeyInput.addEventListener('blur', async () => {
khojKeyInput.value = token; khojKeyInput.value = token;
}); });
const syncButton = document.getElementById('sync-data');
syncButton.addEventListener('click', async () => {
loadingBar.style.display = 'block';
await window.syncDataAPI.syncData(false);
});
const syncForceButton = document.getElementById('sync-force'); const syncForceButton = document.getElementById('sync-force');
syncForceButton.addEventListener('click', async () => { syncForceButton.addEventListener('click', async () => {
loadingBar.style.display = 'block'; loadingBar.style.display = 'block';

View file

@ -188,7 +188,6 @@
fetch(url, { headers }) fetch(url, { headers })
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
console.log(data);
document.getElementById("results").innerHTML = render_results(data, query, type); document.getElementById("results").innerHTML = render_results(data, query, type);
}); });
} }

View file

@ -1,4 +1,4 @@
import { Notice, Plugin } from 'obsidian'; import { Notice, Plugin, request } from 'obsidian';
import { KhojSetting, KhojSettingTab, DEFAULT_SETTINGS } from 'src/settings' import { KhojSetting, KhojSettingTab, DEFAULT_SETTINGS } from 'src/settings'
import { KhojSearchModal } from 'src/search_modal' import { KhojSearchModal } from 'src/search_modal'
import { KhojChatModal } from 'src/chat_modal' import { KhojChatModal } from 'src/chat_modal'
@ -69,6 +69,25 @@ export default class Khoj extends Plugin {
async loadSettings() { async loadSettings() {
// Load khoj obsidian plugin settings // Load khoj obsidian plugin settings
this.settings = Object.assign({}, DEFAULT_SETTINGS, await this.loadData()); this.settings = Object.assign({}, DEFAULT_SETTINGS, await this.loadData());
// Check if khoj backend is configured, note if cannot connect to backend
let headers = { "Authorization": `Bearer ${this.settings.khojApiKey}` };
if (this.settings.khojUrl === "https://app.khoj.dev") {
if (this.settings.khojApiKey === "") {
new Notice(`Khoj API key is not configured. Please visit https://app.khoj.dev to get an API key.`);
return;
}
await request({ url: this.settings.khojUrl ,method: "GET", headers: headers })
.then(response => {
this.settings.connectedToBackend = true;
})
.catch(error => {
this.settings.connectedToBackend = false;
new Notice(`Ensure Khoj backend is running and Khoj URL is pointing to it in the plugin settings.\n\n${error}`);
});
}
} }
async saveSettings() { async saveSettings() {

View file

@ -28,7 +28,7 @@ from khoj.utils.config import (
from khoj.utils.fs_syncer import collect_files from khoj.utils.fs_syncer import collect_files
from khoj.utils.rawconfig import FullConfig from khoj.utils.rawconfig import FullConfig
from khoj.routers.indexer import configure_content, load_content, configure_search from khoj.routers.indexer import configure_content, load_content, configure_search
from database.models import KhojUser from database.models import KhojUser, Subscription
from database.adapters import get_all_users from database.adapters import get_all_users
@ -54,27 +54,37 @@ class UserAuthenticationBackend(AuthenticationBackend):
def _initialize_default_user(self): def _initialize_default_user(self):
if not self.khojuser_manager.filter(username="default").exists(): if not self.khojuser_manager.filter(username="default").exists():
self.khojuser_manager.create_user( default_user = self.khojuser_manager.create_user(
username="default", username="default",
email="default@example.com", email="default@example.com",
password="default", password="default",
) )
Subscription.objects.create(user=default_user, type="standard", renewal_date="2100-04-01")
async def authenticate(self, request: HTTPConnection): async def authenticate(self, request: HTTPConnection):
current_user = request.session.get("user") current_user = request.session.get("user")
if current_user and current_user.get("email"): if current_user and current_user.get("email"):
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst() user = (
await self.khojuser_manager.filter(email=current_user.get("email"))
.prefetch_related("subscription")
.afirst()
)
if user: if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
# Get bearer token from header # Get bearer token from header
bearer_token = request.headers["Authorization"].split("Bearer ")[1] bearer_token = request.headers["Authorization"].split("Bearer ")[1]
# Get user owning token # Get user owning token
user_with_token = await self.khojapiuser_manager.filter(token=bearer_token).select_related("user").afirst() user_with_token = (
await self.khojapiuser_manager.filter(token=bearer_token)
.select_related("user")
.prefetch_related("user__subscription")
.afirst()
)
if user_with_token: if user_with_token:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
if state.anonymous_mode: if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").afirst() user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user: if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)

View file

@ -109,7 +109,7 @@
display: grid; display: grid;
grid-template-rows: repeat(3, 1fr); grid-template-rows: repeat(3, 1fr);
gap: 8px; gap: 8px;
padding: 24px 16px; padding: 24px 16px 8px;
width: 320px; width: 320px;
height: 180px; height: 180px;
background: var(--background-color); background: var(--background-color);
@ -121,7 +121,7 @@
div.finalize-buttons { div.finalize-buttons {
display: grid; display: grid;
gap: 8px; gap: 8px;
padding: 24px 16px; padding: 32px 0px 0px;
width: 320px; width: 320px;
border-radius: 4px; border-radius: 4px;
overflow: hidden; overflow: hidden;
@ -162,10 +162,13 @@
color: grey; color: grey;
font-size: 16px; font-size: 16px;
} }
.card-button-row { .card-description-row {
padding-top: 4px;
}
.card-action-row {
display: grid; display: grid;
grid-template-columns: auto; grid-auto-flow: row;
text-align: right; justify-content: left;
} }
.card-button { .card-button {
border: none; border: none;
@ -271,7 +274,9 @@
100% { transform: rotate(360deg); } 100% { transform: rotate(360deg); }
} }
#status {
padding-top: 32px;
}
div.finalize-actions { div.finalize-actions {
grid-auto-flow: column; grid-auto-flow: column;
grid-gap: 24px; grid-gap: 24px;
@ -287,6 +292,7 @@
select#chat-models { select#chat-models {
margin-bottom: 0; margin-bottom: 0;
padding: 8px;
} }
@ -343,6 +349,12 @@
width: auto; width: auto;
} }
#status {
padding-top: 12px;
}
div.finalize-actions {
padding: 12px 0 0;
}
div.finalize-buttons { div.finalize-buttons {
padding: 0; padding: 0;
} }

View file

@ -43,7 +43,7 @@ To get started, just start typing below. You can also type / to see a list of co
let escaped_ref = reference.replaceAll('"', '"'); let escaped_ref = reference.replaceAll('"', '"');
// Generate HTML for Chat Reference // Generate HTML for Chat Reference
let short_ref = escaped_ref.slice(0, 100); let short_ref = escaped_ref.slice(0, 140);
short_ref = short_ref.length < escaped_ref.length ? short_ref + "..." : short_ref; short_ref = short_ref.length < escaped_ref.length ? short_ref + "..." : short_ref;
let referenceButton = document.createElement('button'); let referenceButton = document.createElement('button');
referenceButton.innerHTML = short_ref; referenceButton.innerHTML = short_ref;
@ -205,8 +205,11 @@ To get started, just start typing below. You can also type / to see a list of co
// Evaluate the contents of new_response_text.innerHTML after all the data has been streamed // Evaluate the contents of new_response_text.innerHTML after all the data has been streamed
const currentHTML = newResponseText.innerHTML; const currentHTML = newResponseText.innerHTML;
newResponseText.innerHTML = formatHTMLMessage(currentHTML); newResponseText.innerHTML = formatHTMLMessage(currentHTML);
newResponseText.appendChild(references); if (references != null) {
newResponseText.appendChild(references);
}
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
return; return;
} }
@ -265,7 +268,6 @@ To get started, just start typing below. You can also type / to see a list of co
}); });
} }
readStream(); readStream();
document.getElementById("chat-input").removeAttribute("disabled");
}); });
} }
@ -417,6 +419,9 @@ To get started, just start typing below. You can also type / to see a list of co
display: block; display: block;
} }
div.references {
padding-top: 8px;
}
div.reference { div.reference {
display: grid; display: grid;
grid-template-rows: auto; grid-template-rows: auto;
@ -447,12 +452,12 @@ To get started, just start typing below. You can also type / to see a list of co
cursor: pointer; cursor: pointer;
transition: background 0.2s ease-in-out; transition: background 0.2s ease-in-out;
text-align: left; text-align: left;
max-height: 50px; max-height: 75px;
transition: max-height 0.3s ease-in-out; transition: max-height 0.3s ease-in-out;
overflow: hidden; overflow: hidden;
} }
button.reference-button.expanded { button.reference-button.expanded {
max-height: 200px; max-height: none;
} }
button.reference-button::before { button.reference-button::before {

View file

@ -29,12 +29,12 @@
{% endif %} {% endif %}
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg> <svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a> </a>
</div> <div id="clear-computer" class="card-action-row"
<div id="clear-computer" class="card-action-row" style="display: {% if not current_model_state.computer %}none{% endif %}">
style="display: {% if not current_model_state.computer %}none{% endif %}"> <button class="card-button" onclick="clearContentType('computer')">
<button class="card-button" onclick="clearContentType('computer')"> Disable
Disable </button>
</button> </div>
</div> </div>
</div> </div>
<div class="card"> <div class="card">
@ -61,13 +61,13 @@
{% endif %} {% endif %}
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg> <svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a> </a>
</div> <div id="clear-github"
<div id="clear-github" class="card-action-row"
class="card-action-row" style="display: {% if not current_model_state.github %}none{% endif %}">
style="display: {% if not current_model_state.github %}none{% endif %}"> <button class="card-button" onclick="clearContentType('github')">
<button class="card-button" onclick="clearContentType('github')"> Disable
Disable </button>
</button> </div>
</div> </div>
</div> </div>
<div class="card"> <div class="card">
@ -94,13 +94,26 @@
{% endif %} {% endif %}
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg> <svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a> </a>
<div id="clear-notion"
class="card-action-row"
style="display: {% if not current_model_state.notion %}none{% endif %}">
<button class="card-button" onclick="clearContentType('notion')">
Disable
</button>
</div>
</div> </div>
<div id="clear-notion" </div>
class="card-action-row" </div>
style="display: {% if not current_model_state.notion %}none{% endif %}"> <div class="general-settings section">
<button class="card-button" onclick="clearContentType('notion')"> <div id="status" style="display: none;"></div>
Disable </div>
</button> <div class="section finalize-actions general-settings">
<div class="section-cards">
<div class="finalize-buttons">
<button id="configure" type="submit" title="Update index with the latest changes">💾 Save All</button>
</div>
<div class="finalize-buttons">
<button id="reinitialize" type="submit" title="Regenerate index from scratch">🔄 Reinitialize</button>
</div> </div>
</div> </div>
</div> </div>
@ -221,23 +234,7 @@
</div> </div>
</div> </div>
{% endif %} {% endif %}
<div class="section general-settings"> <div class="section"></div>
<div id="results-count" title="Number of items to show in search and use for chat response">
<label for="results-count-slider">Results Count: <span id="results-count-value">5</span></label>
<input type="range" id="results-count-slider" name="results-count-slider" min="1" max="10" step="1" value="5">
</div>
<div id="status" style="display: none;"></div>
</div>
<div class="section finalize-actions general-settings">
<div class="section-cards">
<div class="finalize-buttons">
<button id="configure" type="submit" title="Update index with the latest changes">⚙️ Configure</button>
</div>
<div class="finalize-buttons">
<button id="reinitialize" type="submit" title="Regenerate index from scratch">🔄 Reinitialize</button>
</div>
</div>
</div>
</div> </div>
<script> <script>
@ -329,11 +326,11 @@
event.preventDefault(); event.preventDefault();
updateIndex( updateIndex(
force=false, force=false,
successText="Configured successfully!", successText="Saved!",
errorText="Unable to configure. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.", errorText="Unable to configure. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
button=configure, button=configure,
loadingText="Configuring...", loadingText="Saving...",
emoji="⚙️"); emoji="💾");
}); });
var reinitialize = document.getElementById("reinitialize"); var reinitialize = document.getElementById("reinitialize");
@ -341,7 +338,7 @@
event.preventDefault(); event.preventDefault();
updateIndex( updateIndex(
force=true, force=true,
successText="Reinitialized successfully!", successText="Reinitialized!",
errorText="Unable to reinitialize. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.", errorText="Unable to reinitialize. Raise issue on Khoj <a href='https://github.com/khoj-ai/khoj/issues'>Github</a> or <a href='https://discord.gg/BDgyabRM6e'>Discord</a>.",
button=reinitialize, button=reinitialize,
loadingText="Reinitializing...", loadingText="Reinitializing...",
@ -350,6 +347,7 @@
function updateIndex(force, successText, errorText, button, loadingText, emoji) { function updateIndex(force, successText, errorText, button, loadingText, emoji) {
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
const original_html = button.innerHTML;
button.disabled = true; button.disabled = true;
button.innerHTML = emoji + " " + loadingText; button.innerHTML = emoji + " " + loadingText;
fetch('/api/update?&client=web&force=' + force, { fetch('/api/update?&client=web&force=' + force, {
@ -361,15 +359,17 @@
}) })
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
console.log('Success:', data);
if (data.detail != null) { if (data.detail != null) {
throw new Error(data.detail); throw new Error(data.detail);
} }
document.getElementById("status").innerHTML = emoji + " " + successText; document.getElementById("status").style.display = "none";
document.getElementById("status").style.display = "block";
button.disabled = false; button.disabled = false;
button.innerHTML = '✅ Done!'; button.innerHTML = `✅ ${successText}`;
setTimeout(function() {
button.innerHTML = original_html;
}, 2000);
}) })
.catch((error) => { .catch((error) => {
console.error('Error:', error); console.error('Error:', error);
@ -377,6 +377,9 @@
document.getElementById("status").style.display = "block"; document.getElementById("status").style.display = "block";
button.disabled = false; button.disabled = false;
button.innerHTML = '⚠️ Unsuccessful'; button.innerHTML = '⚠️ Unsuccessful';
setTimeout(function() {
button.innerHTML = original_html;
}, 2000);
}); });
content_sources = ["computer", "github", "notion"]; content_sources = ["computer", "github", "notion"];
@ -400,26 +403,6 @@
}); });
} }
// Setup the results count slider
const resultsCountSlider = document.getElementById('results-count-slider');
const resultsCountValue = document.getElementById('results-count-value');
// Set the initial value of the slider
resultsCountValue.textContent = resultsCountSlider.value;
// Store the slider value in localStorage when it changes
resultsCountSlider.addEventListener('input', () => {
resultsCountValue.textContent = resultsCountSlider.value;
localStorage.setItem('khojResultsCount', resultsCountSlider.value);
});
// Get the slider value from localStorage on page load
const storedResultsCount = localStorage.getItem('khojResultsCount');
if (storedResultsCount) {
resultsCountSlider.value = storedResultsCount;
resultsCountValue.textContent = storedResultsCount;
}
function generateAPIKey() { function generateAPIKey() {
const apiKeyList = document.getElementById("api-key-list"); const apiKeyList = document.getElementById("api-key-list");
fetch('/auth/token', { fetch('/auth/token', {

View file

@ -7,7 +7,7 @@
<span class="card-title-text">Files</span> <span class="card-title-text">Files</span>
<div class="instructions"> <div class="instructions">
<p class="card-description">Manage files from your computer</p> <p class="card-description">Manage files from your computer</p>
<p id="get-desktop-client" class="card-description">Download the <a href="https://download.khoj.dev">Khoj Desktop app</a> to sync documents from your computer</p> <p id="get-desktop-client" class="card-description">Download the <a href="https://khoj.dev/downloads">Khoj Desktop app</a> to sync documents from your computer</p>
</div> </div>
</h2> </h2>
<div class="section-manage-files"> <div class="section-manage-files">

View file

@ -46,6 +46,9 @@
</div> </div>
</div> </div>
<style> <style>
td {
padding: 10px 0;
}
div.repo { div.repo {
width: 100%; width: 100%;
height: 100%; height: 100%;
@ -124,6 +127,11 @@
return; return;
} }
const submitButton = document.getElementById("submit");
submitButton.disabled = true;
submitButton.innerHTML = "Saving...";
// Save Github config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/config/data/content-source/github', { fetch('/api/config/data/content-source/github', {
method: 'POST', method: 'POST',
@ -137,15 +145,40 @@
}) })
}) })
.then(response => response.json()) .then(response => response.json())
.then(data => { data["status"] === "ok" ? data : Promise.reject(data) })
.catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Github settings.";
document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save settings";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.disabled = false;
}, 2000);
return;
});
// Index Github content on server
fetch('/api/update?t=github')
.then(response => response.json())
.then(data => { data["status"] == "ok" ? data : Promise.reject(data) })
.then(data => { .then(data => {
if (data["status"] == "ok") { document.getElementById("success").style.display = "none";
document.getElementById("success").innerHTML = "✅ Successfully updated. Go to your <a href='/config'>settings page</a> to complete setup."; submitButton.innerHTML = "✅ Successfully updated";
document.getElementById("success").style.display = "block"; setTimeout(function() {
} else { submitButton.innerHTML = "Save";
document.getElementById("success").innerHTML = "⚠️ Failed to update settings."; submitButton.disabled = false;
document.getElementById("success").style.display = "block"; }, 2000);
}
}) })
.catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Github content.";
document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save content";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.disabled = false;
}, 2000);
});
}); });
</script> </script>
{% endblock %} {% endblock %}

View file

@ -41,6 +41,11 @@
return; return;
} }
const submitButton = document.getElementById("submit");
submitButton.disabled = true;
submitButton.innerHTML = "Saving...";
// Save Notion config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/config/data/content-source/notion', { fetch('/api/config/data/content-source/notion', {
method: 'POST', method: 'POST',
@ -53,15 +58,39 @@
}) })
}) })
.then(response => response.json()) .then(response => response.json())
.then(data => { data["status"] === "ok" ? data : Promise.reject(data) })
.catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Notion settings.";
document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save settings";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.disabled = false;
}, 2000);
return;
});
// Index Notion content on server
fetch('/api/update?t=notion')
.then(response => response.json())
.then(data => { data["status"] == "ok" ? data : Promise.reject(data) })
.then(data => { .then(data => {
if (data["status"] == "ok") { document.getElementById("success").style.display = "none";
document.getElementById("success").innerHTML = "✅ Successfully updated. Go to your <a href='/config'>settings page</a> to complete setup."; submitButton.innerHTML = "✅ Successfully updated";
document.getElementById("success").style.display = "block"; setTimeout(function() {
} else { submitButton.innerHTML = "Save";
document.getElementById("success").innerHTML = "⚠️ Failed to update settings."; submitButton.disabled = false;
document.getElementById("success").style.display = "block"; }, 2000);
}
}) })
.catch(error => {
document.getElementById("success").innerHTML = "⚠️ Failed to save Notion content.";
document.getElementById("success").style.display = "block";
submitButton.innerHTML = "⚠️ Failed to save content";
setTimeout(function() {
submitButton.innerHTML = "Save";
submitButton.disabled = false;
}, 2000);
});
}); });
</script> </script>
{% endblock %} {% endblock %}

View file

@ -189,7 +189,6 @@
}) })
.then(response => response.json()) .then(response => response.json())
.then(data => { .then(data => {
console.log(data);
document.getElementById("results").innerHTML = render_results(data, query, type); document.getElementById("results").innerHTML = render_results(data, query, type);
}); });
} }

View file

@ -56,6 +56,7 @@ locale.setlocale(locale.LC_ALL, "")
from khoj.configure import configure_routes, initialize_server, configure_middleware from khoj.configure import configure_routes, initialize_server, configure_middleware
from khoj.utils import state from khoj.utils import state
from khoj.utils.cli import cli from khoj.utils.cli import cli
from khoj.utils.initialization import initialization
# Setup Logger # Setup Logger
rich_handler = RichHandler(rich_tracebacks=True) rich_handler = RichHandler(rich_tracebacks=True)
@ -74,8 +75,7 @@ def run(should_start_server=True):
args = cli(state.cli_args) args = cli(state.cli_args)
set_state(args) set_state(args)
# Create app directory, if it doesn't exist logger.info(f"🚒 Initializing Khoj v{state.khoj_version}")
state.config_file.parent.mkdir(parents=True, exist_ok=True)
# Set Logging Level # Set Logging Level
if args.verbose == 0: if args.verbose == 0:
@ -83,6 +83,11 @@ def run(should_start_server=True):
elif args.verbose >= 1: elif args.verbose >= 1:
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
initialization()
# Create app directory, if it doesn't exist
state.config_file.parent.mkdir(parents=True, exist_ok=True)
# Set Log File # Set Log File
fh = logging.FileHandler(state.config_file.parent / "khoj.log", encoding="utf-8") fh = logging.FileHandler(state.config_file.parent / "khoj.log", encoding="utf-8")
fh.setLevel(logging.DEBUG) fh.setLevel(logging.DEBUG)
@ -97,7 +102,7 @@ def run(should_start_server=True):
configure_routes(app) configure_routes(app)
# Mount Django and Static Files # Mount Django and Static Files
app.mount("/django", django_app, name="django") app.mount("/server", django_app, name="server")
static_dir = "static" static_dir = "static"
if not os.path.exists(static_dir): if not os.path.exists(static_dir):
os.mkdir(static_dir) os.mkdir(static_dir)

View file

@ -55,10 +55,10 @@ def extract_questions_offline(
last_year = datetime.now().year - 1 last_year = datetime.now().year - 1
last_christmas_date = f"{last_year}-12-25" last_christmas_date = f"{last_year}-12-25"
next_christmas_date = f"{datetime.now().year}-12-25" next_christmas_date = f"{datetime.now().year}-12-25"
system_prompt = prompts.extract_questions_system_prompt_llamav2.format( system_prompt = prompts.system_prompt_extract_questions_gpt4all.format(
message=(prompts.system_prompt_message_extract_questions_llamav2) message=(prompts.system_prompt_message_extract_questions_gpt4all)
) )
example_questions = prompts.extract_questions_llamav2_sample.format( example_questions = prompts.extract_questions_gpt4all_sample.format(
query=text, query=text,
chat_history=chat_history, chat_history=chat_history,
current_date=current_date, current_date=current_date,
@ -150,14 +150,14 @@ def converse_offline(
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message): elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message):
conversation_primer = user_query conversation_primer = user_query
else: else:
conversation_primer = prompts.notes_conversation_llamav2.format( conversation_primer = prompts.notes_conversation_gpt4all.format(
query=user_query, references=compiled_references_message query=user_query, references=compiled_references_message
) )
# Setup Prompt with Primer or Conversation History # Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context( messages = generate_chatml_messages_with_context(
conversation_primer, conversation_primer,
prompts.system_prompt_message_llamav2, prompts.system_prompt_message_gpt4all,
conversation_log, conversation_log,
model_name=model, model_name=model,
max_prompt_size=max_prompt_size, max_prompt_size=max_prompt_size,
@ -183,16 +183,16 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
conversation_history = messages[1:-1] conversation_history = messages[1:-1]
formatted_messages = [ formatted_messages = [
prompts.chat_history_llamav2_from_assistant.format(message=message.content) prompts.khoj_message_gpt4all.format(message=message.content)
if message.role == "assistant" if message.role == "assistant"
else prompts.chat_history_llamav2_from_user.format(message=message.content) else prompts.user_message_gpt4all.format(message=message.content)
for message in conversation_history for message in conversation_history
] ]
stop_words = ["<s>"] stop_words = ["<s>"]
chat_history = "".join(formatted_messages) chat_history = "".join(formatted_messages)
templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content) templated_system_message = prompts.system_prompt_gpt4all.format(message=system_message.content)
templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content) templated_user_message = prompts.user_message_gpt4all.format(message=user_message.content)
prompted_message = templated_system_message + chat_history + templated_user_message prompted_message = templated_system_message + chat_history + templated_user_message
state.chat_lock.acquire() state.chat_lock.acquire()

View file

@ -20,27 +20,6 @@ from khoj.utils.helpers import ConversationCommand, is_none_or_empty
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def summarize(session, model, api_key=None, temperature=0.5, max_tokens=200):
"""
Summarize conversation session using the specified OpenAI chat model
"""
messages = [ChatMessage(content=prompts.summarize_chat.format(), role="system")] + session
# Get Response from GPT
logger.debug(f"Prompt for GPT: {messages}")
response = completion_with_backoff(
messages=messages,
model_name=model,
temperature=temperature,
max_tokens=max_tokens,
model_kwargs={"stop": ['"""'], "frequency_penalty": 0.2},
openai_api_key=api_key,
)
# Extract, Clean Message from GPT's Response
return str(response.content).replace("\n\n", "")
def extract_questions( def extract_questions(
text, text,
model: Optional[str] = "gpt-4", model: Optional[str] = "gpt-4",
@ -131,16 +110,14 @@ def converse(
completion_func(chat_response=prompts.no_notes_found.format()) completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()]) return iter([prompts.no_notes_found.format()])
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references): elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references):
conversation_primer = prompts.general_conversation.format(current_date=current_date, query=user_query) conversation_primer = prompts.general_conversation.format(query=user_query)
else: else:
conversation_primer = prompts.notes_conversation.format( conversation_primer = prompts.notes_conversation.format(query=user_query, references=compiled_references)
current_date=current_date, query=user_query, references=compiled_references
)
# Setup Prompt with Primer or Conversation History # Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context( messages = generate_chatml_messages_with_context(
conversation_primer, conversation_primer,
prompts.personality.format(), prompts.personality.format(current_date=current_date),
conversation_log, conversation_log,
model, model,
max_prompt_size, max_prompt_size,
@ -157,4 +134,5 @@ def converse(
temperature=temperature, temperature=temperature,
openai_api_key=api_key, openai_api_key=api_key,
completion_func=completion_func, completion_func=completion_func,
model_kwargs={"stop": ["Notes:\n["]},
) )

View file

@ -69,15 +69,15 @@ def completion_with_backoff(**kwargs):
reraise=True, reraise=True,
) )
def chat_completion_with_backoff( def chat_completion_with_backoff(
messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None, model_kwargs=None
): ):
g = ThreadedGenerator(compiled_references, completion_func=completion_func) g = ThreadedGenerator(compiled_references, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key)) t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs))
t.start() t.start()
return g return g
def llm_thread(g, messages, model_name, temperature, openai_api_key=None): def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_kwargs=None):
callback_handler = StreamingChatCallbackHandler(g) callback_handler = StreamingChatCallbackHandler(g)
chat = ChatOpenAI( chat = ChatOpenAI(
streaming=True, streaming=True,
@ -86,6 +86,7 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
model_name=model_name, # type: ignore model_name=model_name, # type: ignore
temperature=temperature, temperature=temperature,
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
model_kwargs=model_kwargs,
request_timeout=20, request_timeout=20,
max_retries=1, max_retries=1,
client=None, client=None,

View file

@ -4,30 +4,44 @@ from langchain.prompts import PromptTemplate
## Personality ## Personality
## -- ## --
personality = PromptTemplate.from_template("You are Khoj, a smart, inquisitive and helpful personal assistant.") personality = PromptTemplate.from_template(
"""
You are Khoj, a smart, inquisitive and helpful personal assistant.
Use your general knowledge and the past conversation with the user as context to inform your responses.
You were created by Khoj Inc. with the following capabilities:
- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you.
- You cannot set reminders.
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev.
Today is {current_date} in UTC.
""".strip()
)
## General Conversation ## General Conversation
## -- ## --
general_conversation = PromptTemplate.from_template( general_conversation = PromptTemplate.from_template(
""" """
Using your general knowledge and our past conversations as context, answer the following question. {query}
Current Date: {current_date}
Question: {query}
""".strip() """.strip()
) )
no_notes_found = PromptTemplate.from_template( no_notes_found = PromptTemplate.from_template(
""" """
I'm sorry, I couldn't find any relevant notes to respond to your message. I'm sorry, I couldn't find any relevant notes to respond to your message.
""".strip() """.strip()
) )
system_prompt_message_llamav2 = f"""You are Khoj, a smart, inquisitive and helpful personal assistant. ## Conversation Prompts for GPT4All Models
## --
system_prompt_message_gpt4all = f"""You are Khoj, a smart, inquisitive and helpful personal assistant.
Using your general knowledge and our past conversations as context, answer the following question. Using your general knowledge and our past conversations as context, answer the following question.
If you do not know the answer, say 'I don't know.'""" If you do not know the answer, say 'I don't know.'"""
system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and intelligent personal assistant. When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective. system_prompt_message_extract_questions_gpt4all = f"""You are Khoj, a kind and intelligent personal assistant. When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
- Write the question as if you can search for the answer on the user's personal notes. - Write the question as if you can search for the answer on the user's personal notes.
- Try to be as specific as possible. Instead of saying "they" or "it" or "he", use the name of the person or thing you are referring to. For example, instead of saying "Which store did they go to?", say "Which store did Alice and Bob go to?". - Try to be as specific as possible. Instead of saying "they" or "it" or "he", use the name of the person or thing you are referring to. For example, instead of saying "Which store did they go to?", say "Which store did Alice and Bob go to?".
- Add as much context from the previous questions and notes as required into your search queries. - Add as much context from the previous questions and notes as required into your search queries.
@ -35,61 +49,47 @@ system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and i
What follow-up questions, if any, will you need to ask to answer the user's question? What follow-up questions, if any, will you need to ask to answer the user's question?
""" """
system_prompt_llamav2 = PromptTemplate.from_template( system_prompt_gpt4all = PromptTemplate.from_template(
""" """
<s>[INST] <<SYS>> <s>[INST] <<SYS>>
{message} {message}
<</SYS>>Hi there! [/INST] Hello! How can I help you today? </s>""" <</SYS>>Hi there! [/INST] Hello! How can I help you today? </s>"""
) )
extract_questions_system_prompt_llamav2 = PromptTemplate.from_template( system_prompt_extract_questions_gpt4all = PromptTemplate.from_template(
""" """
<s>[INST] <<SYS>> <s>[INST] <<SYS>>
{message} {message}
<</SYS>>[/INST]</s>""" <</SYS>>[/INST]</s>"""
) )
general_conversation_llamav2 = PromptTemplate.from_template( user_message_gpt4all = PromptTemplate.from_template(
"""
<s>[INST] {query} [/INST]
""".strip()
)
chat_history_llamav2_from_user = PromptTemplate.from_template(
""" """
<s>[INST] {message} [/INST] <s>[INST] {message} [/INST]
""".strip() """.strip()
) )
chat_history_llamav2_from_assistant = PromptTemplate.from_template( khoj_message_gpt4all = PromptTemplate.from_template(
""" """
{message}</s> {message}</s>
""".strip() """.strip()
) )
conversation_llamav2 = PromptTemplate.from_template(
"""
<s>[INST] {query} [/INST]
""".strip()
)
## Notes Conversation ## Notes Conversation
## -- ## --
notes_conversation = PromptTemplate.from_template( notes_conversation = PromptTemplate.from_template(
""" """
Using my personal notes and our past conversations as context, answer the following question. Use my personal notes and our past conversations to inform your response.
Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations. Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided notes or past conversations.
These questions should end with a question mark.
Current Date: {current_date}
Notes: Notes:
{references} {references}
Question: {query} Query: {query}
""".strip() """.strip()
) )
notes_conversation_llamav2 = PromptTemplate.from_template( notes_conversation_gpt4all = PromptTemplate.from_template(
""" """
User's Notes: User's Notes:
{references} {references}
@ -98,13 +98,6 @@ Question: {query}
) )
## Summarize Chat
## --
summarize_chat = PromptTemplate.from_template(
f"{personality.format()} Summarize the conversation from your first person perspective"
)
## Summarize Notes ## Summarize Notes
## -- ## --
summarize_notes = PromptTemplate.from_template( summarize_notes = PromptTemplate.from_template(
@ -132,7 +125,10 @@ Question: {user_query}
Answer (in second person):""" Answer (in second person):"""
) )
extract_questions_llamav2_sample = PromptTemplate.from_template(
## Extract Questions
## --
extract_questions_gpt4all_sample = PromptTemplate.from_template(
""" """
<s>[INST] <<SYS>>Current Date: {current_date}<</SYS>> [/INST]</s> <s>[INST] <<SYS>>Current Date: {current_date}<</SYS>> [/INST]</s>
<s>[INST] How was my trip to Cambodia? [/INST] <s>[INST] How was my trip to Cambodia? [/INST]
@ -157,8 +153,6 @@ Use these notes from the user's previous conversations to provide a response:
) )
## Extract Questions
## --
extract_questions = PromptTemplate.from_template( extract_questions = PromptTemplate.from_template(
""" """
You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the user's notes. You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the user's notes.

View file

@ -27,5 +27,5 @@ class CrossEncoderModel:
def predict(self, query, hits: List[SearchResponse]): def predict(self, query, hits: List[SearchResponse]):
cross__inp = [[query, hit.additional["compiled"]] for hit in hits] cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
cross_scores = self.cross_encoder_model.predict(cross__inp) cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True)
return cross_scores return cross_scores

View file

@ -7,7 +7,7 @@ import json
from typing import List, Optional, Union, Any from typing import List, Optional, Union, Any
# External Packages # External Packages
from fastapi import APIRouter, HTTPException, Header, Request from fastapi import APIRouter, Depends, HTTPException, Header, Request
from starlette.authentication import requires from starlette.authentication import requires
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
@ -36,6 +36,7 @@ from khoj.routers.helpers import (
agenerate_chat_response, agenerate_chat_response,
update_telemetry_state, update_telemetry_state,
is_ready_to_chat, is_ready_to_chat,
ApiUserRateLimiter,
) )
from khoj.processor.conversation.prompts import help_message from khoj.processor.conversation.prompts import help_message
from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.openai.gpt import extract_questions
@ -177,11 +178,15 @@ async def set_content_config_github_data(
user = request.user.object user = request.user.object
await adapters.set_user_github_config( try:
user=user, await adapters.set_user_github_config(
pat_token=updated_config.pat_token, user=user,
repos=updated_config.repos, pat_token=updated_config.pat_token,
) repos=updated_config.repos,
)
except Exception as e:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail="Failed to set Github config")
update_telemetry_state( update_telemetry_state(
request=request, request=request,
@ -205,10 +210,14 @@ async def set_content_config_notion_data(
user = request.user.object user = request.user.object
await adapters.set_notion_config( try:
user=user, await adapters.set_notion_config(
token=updated_config.token, user=user,
) token=updated_config.token,
)
except Exception as e:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail="Failed to set Github config")
update_telemetry_state( update_telemetry_state(
request=request, request=request,
@ -348,7 +357,7 @@ async def search(
n: Optional[int] = 5, n: Optional[int] = 5,
t: Optional[SearchType] = SearchType.All, t: Optional[SearchType] = SearchType.All,
r: Optional[bool] = False, r: Optional[bool] = False,
score_threshold: Optional[Union[float, None]] = None, max_distance: Optional[Union[float, None]] = None,
dedupe: Optional[bool] = True, dedupe: Optional[bool] = True,
client: Optional[str] = None, client: Optional[str] = None,
user_agent: Optional[str] = Header(None), user_agent: Optional[str] = Header(None),
@ -367,12 +376,12 @@ async def search(
# initialize variables # initialize variables
user_query = q.strip() user_query = q.strip()
results_count = n or 5 results_count = n or 5
score_threshold = score_threshold if score_threshold is not None else -math.inf max_distance = max_distance if max_distance is not None else math.inf
search_futures: List[concurrent.futures.Future] = [] search_futures: List[concurrent.futures.Future] = []
# return cached results, if available # return cached results, if available
if user: if user:
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}" query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
if query_cache_key in state.query_cache[user.uuid]: if query_cache_key in state.query_cache[user.uuid]:
logger.debug(f"Return response from query cache") logger.debug(f"Return response from query cache")
return state.query_cache[user.uuid][query_cache_key] return state.query_cache[user.uuid][query_cache_key]
@ -409,8 +418,7 @@ async def search(
user_query, user_query,
t, t,
question_embedding=encoded_asymmetric_query, question_embedding=encoded_asymmetric_query,
rank_results=r or False, max_distance=max_distance,
score_threshold=score_threshold,
) )
] ]
@ -423,7 +431,6 @@ async def search(
results_count, results_count,
state.search_models.image_search, state.search_models.image_search,
state.content_index.image, state.content_index.image,
score_threshold=score_threshold,
) )
] ]
@ -446,11 +453,10 @@ async def search(
# Collate results # Collate results
results += text_search.collate_results(hits, dedupe=dedupe) results += text_search.collate_results(hits, dedupe=dedupe)
if r:
results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
else:
# Sort results across all content types and take top results # Sort results across all content types and take top results
results = sorted(results, key=lambda x: float(x.score))[:results_count] results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[
:results_count
]
# Cache results # Cache results
if user: if user:
@ -575,11 +581,14 @@ async def chat(
request: Request, request: Request,
q: str, q: str,
n: Optional[int] = 5, n: Optional[int] = 5,
d: Optional[float] = 0.15,
client: Optional[str] = None, client: Optional[str] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
user_agent: Optional[str] = Header(None), user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None), referer: Optional[str] = Header(None),
host: Optional[str] = Header(None), host: Optional[str] = Header(None),
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
) -> Response: ) -> Response:
user = request.user.object user = request.user.object
@ -591,7 +600,7 @@ async def chat(
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, meta_log, q, (n or 5), conversation_command request, meta_log, q, (n or 5), (d or math.inf), conversation_command
) )
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references): if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
@ -606,7 +615,7 @@ async def chat(
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200) return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
# Get the (streamed) chat response from the LLM of choice. # Get the (streamed) chat response from the LLM of choice.
llm_response = await agenerate_chat_response( llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query, defiltered_query,
meta_log, meta_log,
compiled_references, compiled_references,
@ -615,6 +624,19 @@ async def chat(
user, user,
) )
chat_metadata.update({"conversation_command": conversation_command.value})
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
client=client,
user_agent=user_agent,
referer=referer,
host=host,
metadata=chat_metadata,
)
if llm_response is None: if llm_response is None:
return Response(content=llm_response, media_type="text/plain", status_code=500) return Response(content=llm_response, media_type="text/plain", status_code=500)
@ -634,16 +656,6 @@ async def chat(
response_obj = {"response": actual_response, "context": compiled_references} response_obj = {"response": actual_response, "context": compiled_references}
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
client=client,
user_agent=user_agent,
referer=referer,
host=host,
)
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200) return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
@ -652,6 +664,7 @@ async def extract_references_and_questions(
meta_log: dict, meta_log: dict,
q: str, q: str,
n: int, n: int,
d: float,
conversation_type: ConversationCommand = ConversationCommand.Default, conversation_type: ConversationCommand = ConversationCommand.Default,
): ):
user = request.user.object if request.user.is_authenticated else None user = request.user.object if request.user.is_authenticated else None
@ -663,7 +676,7 @@ async def extract_references_and_questions(
if conversation_type == ConversationCommand.General: if conversation_type == ConversationCommand.General:
return compiled_references, inferred_queries, q return compiled_references, inferred_queries, q
if not sync_to_async(EntryAdapters.user_has_entries)(user=user): if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
logger.warning( logger.warning(
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes." "No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
) )
@ -712,7 +725,7 @@ async def extract_references_and_questions(
request=request, request=request,
n=n_items, n=n_items,
r=True, r=True,
score_threshold=-5.0, max_distance=d,
dedupe=False, dedupe=False,
) )
) )

View file

@ -16,7 +16,7 @@ from google.auth.transport import requests as google_requests
# Internal Packages # Internal Packages
from database.adapters import get_khoj_tokens, get_or_create_user, create_khoj_token, delete_khoj_token from database.adapters import get_khoj_tokens, get_or_create_user, create_khoj_token, delete_khoj_token
from database.models import KhojApiUser from khoj.routers.helpers import update_telemetry_state
from khoj.utils import state from khoj.utils import state
@ -100,6 +100,16 @@ async def auth(request: Request):
if khoj_user: if khoj_user:
request.session["user"] = dict(idinfo) request.session["user"] = dict(idinfo)
if not khoj_user.last_login:
update_telemetry_state(
request=request,
telemetry_type="api",
api="create_user",
metadata={"user_id": str(khoj_user.uuid)},
)
logger.log(logging.INFO, f"New User Created: {khoj_user.uuid}")
RedirectResponse(url="/?status=welcome")
return RedirectResponse(url="/") return RedirectResponse(url="/")

View file

@ -1,21 +1,27 @@
import logging # Standard Packages
import asyncio import asyncio
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from typing import Iterator, List, Optional, Union import logging
from concurrent.futures import ThreadPoolExecutor from time import time
from typing import Iterator, List, Optional, Union, Tuple, Dict
# External Packages
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
# Internal Packages
from khoj.utils import state from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry from khoj.utils.helpers import ConversationCommand, log_telemetry
from khoj.processor.conversation.openai.gpt import converse from khoj.processor.conversation.openai.gpt import converse
from khoj.processor.conversation.gpt4all.chat_model import converse_offline from khoj.processor.conversation.gpt4all.chat_model import converse_offline
from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
from database.models import KhojUser from database.models import KhojUser, Subscription
from database.adapters import ConversationAdapters from database.adapters import ConversationAdapters
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1) executor = ThreadPoolExecutor(max_workers=1)
@ -61,12 +67,15 @@ def update_telemetry_state(
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
): ):
user: KhojUser = request.user.object if request.user.is_authenticated else None user: KhojUser = request.user.object if request.user.is_authenticated else None
subscription: Subscription = user.subscription if user and user.subscription else None
user_state = { user_state = {
"client_host": request.client.host if request.client else None, "client_host": request.client.host if request.client else None,
"user_agent": user_agent or "unknown", "user_agent": user_agent or "unknown",
"referer": referer or "unknown", "referer": referer or "unknown",
"host": host or "unknown", "host": host or "unknown",
"server_id": str(user.uuid) if user else None, "server_id": str(user.uuid) if user else None,
"subscription_type": subscription.type if subscription else None,
"is_recurring": subscription.is_recurring if subscription else None,
} }
if metadata: if metadata:
@ -109,7 +118,7 @@ def generate_chat_response(
inferred_queries: List[str] = [], inferred_queries: List[str] = [],
conversation_command: ConversationCommand = ConversationCommand.Default, conversation_command: ConversationCommand = ConversationCommand.Default,
user: KhojUser = None, user: KhojUser = None,
) -> Union[ThreadedGenerator, Iterator[str]]: ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
def _save_to_conversation_log( def _save_to_conversation_log(
q: str, q: str,
chat_response: str, chat_response: str,
@ -132,6 +141,8 @@ def generate_chat_response(
chat_response = None chat_response = None
logger.debug(f"Conversation Type: {conversation_command.name}") logger.debug(f"Conversation Type: {conversation_command.name}")
metadata = {}
try: try:
partial_completion = partial( partial_completion = partial(
_save_to_conversation_log, _save_to_conversation_log,
@ -148,8 +159,8 @@ def generate_chat_response(
conversation_config = ConversationAdapters.get_default_conversation_config() conversation_config = ConversationAdapters.get_default_conversation_config()
openai_chat_config = ConversationAdapters.get_openai_conversation_config() openai_chat_config = ConversationAdapters.get_openai_conversation_config()
if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline": if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
if state.gpt4all_processor_config.loaded_model is None: if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(offline_chat_config.chat_model) state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
loaded_model = state.gpt4all_processor_config.loaded_model loaded_model = state.gpt4all_processor_config.loaded_model
chat_response = converse_offline( chat_response = converse_offline(
@ -179,8 +190,33 @@ def generate_chat_response(
tokenizer_name=conversation_config.tokenizer, tokenizer_name=conversation_config.tokenizer,
) )
metadata.update({"chat_model": conversation_config.chat_model})
except Exception as e: except Exception as e:
logger.error(e, exc_info=True) logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
return chat_response return chat_response, metadata
class ApiUserRateLimiter:
    """Sliding-window rate limiter keyed on the authenticated user's UUID.

    Allows at most `requests` calls per `window` seconds for each user.
    Intended for use as a FastAPI dependency; raises HTTP 429 when the
    limit is exceeded.
    """

    def __init__(self, requests: int, window: int):
        # Maximum number of requests allowed within the window
        self.requests = requests
        # Window length in seconds
        self.window = window
        # Per-user list of request timestamps, appended in increasing order
        self.cache: dict[str, list[float]] = defaultdict(list)

    def __call__(self, request: Request):
        user: KhojUser = request.user.object
        user_requests = self.cache[user.uuid]

        # Drop timestamps that fell outside the time window in one O(n) pass.
        # (The previous per-element `pop(0)` loop was O(n) per pop, O(n^2) total.)
        # Slice assignment keeps the same list object alive in the cache.
        cutoff = time() - self.window
        user_requests[:] = [timestamp for timestamp in user_requests if timestamp >= cutoff]

        # Reject the request if the user has exceeded the rate limit
        if len(user_requests) >= self.requests:
            raise HTTPException(status_code=429, detail="Too Many Requests")

        # Record the current request in the window
        user_requests.append(time())

View file

@ -146,7 +146,7 @@ def extract_metadata(image_name):
async def query( async def query(
raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf
): ):
# Set query to image content if query is of form file:/path/to/file.png # Set query to image content if query is of form file:/path/to/file.png
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file(): if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
@ -167,7 +167,8 @@ async def query(
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings. # Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
with timer("Search Time", logger): with timer("Search Time", logger):
image_hits = { image_hits = {
result["corpus_id"]: {"image_score": result["score"], "score": result["score"]} # Map scores to distance metric by multiplying by -1
result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]}
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0] for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
} }
@ -204,7 +205,7 @@ async def query(
] ]
# Filter results by score threshold # Filter results by score threshold
hits = [hit for hit in hits if hit["image_score"] >= score_threshold] hits = [hit for hit in hits if hit["image_score"] <= score_threshold]
# Sort the images based on their combined metadata, image scores # Sort the images based on their combined metadata, image scores
return sorted(hits, key=lambda hit: hit["score"], reverse=True) return sorted(hits, key=lambda hit: hit["score"], reverse=True)

View file

@ -104,8 +104,7 @@ async def query(
raw_query: str, raw_query: str,
type: SearchType = SearchType.All, type: SearchType = SearchType.All,
question_embedding: Union[torch.Tensor, None] = None, question_embedding: Union[torch.Tensor, None] = None,
rank_results: bool = False, max_distance: float = math.inf,
score_threshold: float = -math.inf,
) -> Tuple[List[dict], List[Entry]]: ) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query" "Search for entries that answer the query"
@ -127,6 +126,7 @@ async def query(
max_results=top_k, max_results=top_k,
file_type_filter=file_type, file_type_filter=file_type,
raw_query=raw_query, raw_query=raw_query,
max_distance=max_distance,
).all() ).all()
hits = await sync_to_async(list)(hits) # type: ignore[call-arg] hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
@ -177,12 +177,16 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
) )
def rerank_and_sort_results(hits, query): def rerank_and_sort_results(hits, query, rank_results):
# If we have more than one result and reranking is enabled
rank_results = rank_results and len(list(hits)) > 1
# Score all retrieved entries using the cross-encoder # Score all retrieved entries using the cross-encoder
hits = cross_encoder_score(query, hits) if rank_results:
hits = cross_encoder_score(query, hits)
# Sort results by cross-encoder score followed by bi-encoder score # Sort results by cross-encoder score followed by bi-encoder score
hits = sort_results(rank_results=True, hits=hits) hits = sort_results(rank_results=rank_results, hits=hits)
return hits return hits
@ -217,9 +221,9 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
with timer("Cross-Encoder Predict Time", logger, state.device): with timer("Cross-Encoder Predict Time", logger, state.device):
cross_scores = state.cross_encoder_model.predict(query, hits) cross_scores = state.cross_encoder_model.predict(query, hits)
# Store cross-encoder scores in results dictionary for ranking # Convert cross-encoder scores to distances and pass in hits for reranking
for idx in range(len(cross_scores)): for idx in range(len(cross_scores)):
hits[idx]["cross_score"] = cross_scores[idx] hits[idx]["cross_score"] = 1 - cross_scores[idx]
return hits return hits
@ -227,7 +231,7 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]: def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
"""Order results by cross-encoder score followed by bi-encoder score""" """Order results by cross-encoder score followed by bi-encoder score"""
with timer("Rank Time", logger, state.device): with timer("Rank Time", logger, state.device):
hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score hits.sort(key=lambda x: x["score"]) # sort by bi-encoder score
if rank_results: if rank_results:
hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score hits.sort(key=lambda x: x["cross_score"]) # sort by cross-encoder score
return hits return hits

View file

@ -6,6 +6,7 @@ empty_escape_sequences = "\n|\r|\t| "
app_env_filepath = "~/.khoj/env" app_env_filepath = "~/.khoj/env"
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry" telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
content_directory = "~/.khoj/content/" content_directory = "~/.khoj/content/"
default_offline_chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf"
empty_config = { empty_config = {
"search-type": { "search-type": {

View file

@ -0,0 +1,98 @@
import logging
import os
from database.models import (
KhojUser,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig,
ChatModelOptions,
)
from khoj.utils.constants import default_offline_chat_model
from database.adapters import ConversationAdapters
logger = logging.getLogger(__name__)
def initialization():
    """Interactively bootstrap a fresh Khoj server.

    Creates the admin (superuser) account and a default chat model
    configuration when neither exists yet. Each step retries in a loop on
    failure; everything can also be configured later from /server/admin.
    Environments without interactive stdin skip the chat configuration.
    """

    def _create_admin_user():
        # Create a Django superuser from env vars, falling back to interactive prompts
        logger.info(
            "👩‍✈️ Setting up admin user. These credentials will allow you to configure your server at /server/admin."
        )
        email_addr = os.getenv("KHOJ_ADMIN_EMAIL") or input("Email: ")
        password = os.getenv("KHOJ_ADMIN_PASSWORD") or input("Password: ")
        admin_user = KhojUser.objects.create_superuser(email=email_addr, username=email_addr, password=password)
        logger.info(f"👩‍✈️ Created admin user: {admin_user.email}")

    def _create_chat_configuration():
        logger.info(
            "🗣️ Configure chat models available to your server. You can always update these at /server/admin using the credentials of your admin account"
        )

        try:
            # Some environments don't support interactive input. We catch the exception and return if that's the case. The admin can still configure their settings from the admin page.
            input()
        except EOFError:
            return

        try:
            # Note: gpt4all package is not available on all devices.
            # So ensure gpt4all package is installed before continuing this step.
            import gpt4all

            use_offline_model = input("Use offline chat model? (y/n): ")
            if use_offline_model == "y":
                logger.info("🗣️ Setting up offline chat model")
                OfflineChatProcessorConversationConfig.objects.create(enabled=True)

                offline_chat_model = input(
                    f"Enter the name of the offline chat model you want to use, based on the models in HuggingFace (press enter to use the default: {default_offline_chat_model}): "
                )
                if offline_chat_model == "":
                    ChatModelOptions.objects.create(
                        chat_model=default_offline_chat_model, model_type=ChatModelOptions.ModelType.OFFLINE
                    )
                else:
                    max_tokens = input("Enter the maximum number of tokens to use for the offline chat model:")
                    tokenizer = input("Enter the tokenizer to use for the offline chat model:")
                    ChatModelOptions.objects.create(
                        chat_model=offline_chat_model,
                        model_type=ChatModelOptions.ModelType.OFFLINE,
                        max_prompt_size=max_tokens,
                        tokenizer=tokenizer,
                    )
        except ModuleNotFoundError:
            # Offline chat requires the optional gpt4all dependency
            logger.warning("Offline models are not supported on this device.")

        use_openai_model = input("Use OpenAI chat model? (y/n): ")
        if use_openai_model == "y":
            logger.info("🗣️ Setting up OpenAI chat model")
            api_key = input("Enter your OpenAI API key: ")
            OpenAIProcessorConversationConfig.objects.create(api_key=api_key)
            openai_chat_model = input("Enter the name of the OpenAI chat model you want to use: ")
            max_tokens = input("Enter the maximum number of tokens to use for the OpenAI chat model:")
            # Use max_prompt_size to match the ChatModelOptions field used for
            # offline chat models above (was inconsistently passed as max_tokens)
            ChatModelOptions.objects.create(
                chat_model=openai_chat_model, model_type=ChatModelOptions.ModelType.OPENAI, max_prompt_size=max_tokens
            )

        logger.info("🗣️ Chat model configuration complete")

    admin_user = KhojUser.objects.filter(is_staff=True).first()
    if admin_user is None:
        while True:
            try:
                _create_admin_user()
                break
            except Exception as e:
                logger.error(f"🚨 Failed to create admin user: {e}", exc_info=True)

    chat_config = ConversationAdapters.get_default_conversation_config()
    if admin_user is None and chat_config is None:
        while True:
            try:
                _create_chat_configuration()
                break
            except Exception as e:
                logger.error(f"🚨 Failed to create chat configuration: {e}", exc_info=True)

View file

@ -43,6 +43,7 @@ from tests.helpers import (
OpenAIProcessorConversationConfigFactory, OpenAIProcessorConversationConfigFactory,
OfflineChatProcessorConversationConfigFactory, OfflineChatProcessorConversationConfigFactory,
UserConversationProcessorConfigFactory, UserConversationProcessorConfigFactory,
SubscriptionFactory,
) )
@ -69,7 +70,9 @@ def search_config() -> SearchConfig:
@pytest.mark.django_db @pytest.mark.django_db
@pytest.fixture @pytest.fixture
def default_user(): def default_user():
return UserFactory() user = UserFactory()
SubscriptionFactory(user=user)
return user
@pytest.mark.django_db @pytest.mark.django_db
@ -78,11 +81,31 @@ def default_user2():
if KhojUser.objects.filter(username="default").exists(): if KhojUser.objects.filter(username="default").exists():
return KhojUser.objects.get(username="default") return KhojUser.objects.get(username="default")
return KhojUser.objects.create( user = KhojUser.objects.create(
username="default", username="default",
email="default@example.com", email="default@example.com",
password="default", password="default",
) )
SubscriptionFactory(user=user)
return user
@pytest.mark.django_db
@pytest.fixture
def default_user3():
    """
    This user should not have any data associated with it
    """
    # Reuse the user from a prior fixture invocation if it already exists
    existing_user = KhojUser.objects.filter(username="default3").first()
    if existing_user:
        return existing_user

    new_user = KhojUser.objects.create(
        username="default3",
        email="default3@example.com",
        password="default3",
    )
    SubscriptionFactory(user=new_user)
    return new_user
@pytest.mark.django_db @pytest.mark.django_db
@ -111,6 +134,19 @@ def api_user2(default_user2):
) )
@pytest.mark.django_db
@pytest.fixture
def api_user3(default_user3):
    # Reuse this user's API key if one was already created in a prior invocation
    existing_key = KhojApiUser.objects.filter(user=default_user3).first()
    if existing_key:
        return existing_key

    return KhojApiUser.objects.create(
        user=default_user3,
        name="api-key",
        token="kk-diff-secret-3",
    )
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def search_models(search_config: SearchConfig): def search_models(search_config: SearchConfig):
search_models = SearchModels() search_models = SearchModels()
@ -206,7 +242,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser):
OpenAIProcessorConversationConfigFactory() OpenAIProcessorConversationConfigFactory()
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model) UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
state.anonymous_mode = False state.anonymous_mode = True
app = FastAPI() app = FastAPI()
@ -224,7 +260,9 @@ def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUs
# Initialize Processor from Config # Initialize Processor from Config
if os.getenv("OPENAI_API_KEY"): if os.getenv("OPENAI_API_KEY"):
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
OpenAIProcessorConversationConfigFactory() OpenAIProcessorConversationConfigFactory()
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
state.anonymous_mode = True state.anonymous_mode = True

View file

@ -9,6 +9,7 @@ from database.models import (
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
UserConversationConfig, UserConversationConfig,
Conversation, Conversation,
Subscription,
) )
@ -68,3 +69,13 @@ class ConversationFactory(factory.django.DjangoModelFactory):
model = Conversation model = Conversation
user = factory.SubFactory(UserFactory) user = factory.SubFactory(UserFactory)
class SubscriptionFactory(factory.django.DjangoModelFactory):
    # Builds a Subscription row for tests: a non-recurring "standard" plan
    # with a renewal date far in the future so it never expires mid-test.
    class Meta:
        model = Subscription

    user = factory.SubFactory(UserFactory)
    type = "standard"
    is_recurring = False
    renewal_date = "2100-04-01"

View file

@ -16,7 +16,7 @@ from khoj.utils.state import search_models, content_index, config
from khoj.search_type import text_search, image_search from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_entries import OrgToEntries from khoj.processor.org_mode.org_to_entries import OrgToEntries
from database.models import KhojUser from database.models import KhojUser, KhojApiUser
from database.adapters import EntryAdapters from database.adapters import EntryAdapters
@ -351,6 +351,24 @@ def test_different_user_data_not_accessed(client, sample_org_data, default_user:
assert len(response.json()) == 1 and response.json()["detail"] == "Forbidden" assert len(response.json()) == 1 and response.json()["detail"] == "Forbidden"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiUser):
    # Arrange
    auth_headers = {"Authorization": "Bearer " + api_user3.token}
    user_query = quote("How to git install application?")

    # Act
    response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=auth_headers)

    # Assert
    assert response.status_code == 200
    # assert actual response has no data as the default_user3, though other users have data
    assert response.json() == []
    assert len(response.json()) == 0
def get_sample_files_data(): def get_sample_files_data():
return [ return [
("files", ("path/to/filename.org", "* practicing piano", "text/org")), ("files", ("path/to/filename.org", "* practicing piano", "text/org")),

View file

@ -307,6 +307,8 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_
"which one is", "which one is",
"which of namita's sons", "which of namita's sons",
"the birth order", "the birth order",
"provide more context",
"provide me with more context",
] ]
assert response.status_code == 200 assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (