Add additional personalization in Chat via Location, Username (#644)

* Add location metadata to chat history
* Add support for custom configuration of the user's name
* Add region, country, and city to the desktop app's chat request URL for added context (see the request sketch after this list)
* Update prompts to specify the user's location, rather than just a location
* Add location data to Obsidian chat query
* Use first word for first name, last word for last name when setting profile name
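
For reference, a minimal sketch (not part of this commit) of how a client could pass the new location parameters to the chat API. The host URL, token, query, and location values are placeholders; the q/region/city/country parameter names are the ones added to the /api/chat endpoint in this change.

import requests

KHOJ_URL = "http://localhost:42110"  # assumed local server address
TOKEN = "<your-khoj-api-token>"      # placeholder

# region, city and country are the new optional query parameters
params = {
    "q": "Where should I get dinner tonight?",
    "region": "Karnataka",
    "city": "Bengaluru",
    "country": "India",
}
response = requests.get(
    f"{KHOJ_URL}/api/chat",
    params=params,
    headers={"Authorization": f"Bearer {TOKEN}"},
)
print(response.text)
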
sabaimran 2024-02-13 03:35:13 -08:00 committed by GitHub
parent a3eb17b7d4
commit 32ec54172e
18 changed files with 286 additions and 22 deletions

View file

@ -31,6 +31,22 @@
});
}
let region = null;
let city = null;
let countryName = null;
fetch("https://ipapi.co/json")
.then(response => response.json())
.then(data => {
region = data.region;
city = data.city;
countryName = data.country_name;
})
.catch(err => {
console.log(err);
return;
});
function formatDate(date) {
// Format date in HH:MM, DD MMM YYYY format
let time_string = date.toLocaleTimeString('en-IN', { hour: '2-digit', minute: '2-digit', hour12: false });
@ -337,7 +353,7 @@
// Generate backend API URL to execute query
let url = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}`;
let url = `${hostURL}/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}&region=${region}&city=${city}&country=${countryName}`;
const khojToken = await window.tokenAPI.getToken();
const headers = { 'Authorization': `Bearer ${khojToken}` };
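
The clients resolve coarse location from the public ipapi.co endpoint, as in the fetch() block above. As a minimal illustration, the same lookup in Python; the endpoint URL and the region/city/country_name fields are taken from the JavaScript above, and the timeout is arbitrary.

import requests

# Same public geo-IP endpoint the web, desktop, and Obsidian clients call
data = requests.get("https://ipapi.co/json", timeout=5).json()
region, city, country = data.get("region"), data.get("city"), data.get("country_name")
print(region, city, country)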

View file

@ -10,6 +10,9 @@ export interface ChatJsonResult {
export class KhojChatModal extends Modal {
result: string;
setting: KhojSetting;
region: string;
city: string;
countryName: string;
constructor(app: App, setting: KhojSetting) {
super(app);
@ -17,6 +20,19 @@ export class KhojChatModal extends Modal {
// Register Modal Keybindings to send user message
this.scope.register([], 'Enter', async () => { await this.chat() });
fetch("https://ipapi.co/json")
.then(response => response.json())
.then(data => {
this.region = data.region;
this.city = data.city;
this.countryName = data.country_name;
})
.catch(err => {
console.log(err);
return;
});
}
async chat() {
@ -354,7 +370,7 @@ export class KhojChatModal extends Modal {
// Get chat response from Khoj backend
let encodedQuery = encodeURIComponent(query);
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian&stream=true`;
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian&stream=true&region=${this.region}&city=${this.city}&country=${this.countryName}`;
let responseElement = this.createKhojResponseDiv();
// Temporary status message to indicate that Khoj is thinking

View file

@ -173,6 +173,24 @@ async def create_user_by_google_token(token: dict) -> KhojUser:
return user
def set_user_name(user: KhojUser, first_name: str, last_name: str) -> KhojUser:
user.first_name = first_name
user.last_name = last_name
user.save()
return user
def get_user_name(user: KhojUser):
full_name = user.get_full_name()
if not is_none_or_empty(full_name):
return full_name
google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
if google_profile:
return google_profile.given_name
return None
def get_user_subscription(email: str) -> Optional[Subscription]:
return Subscription.objects.filter(user__email=email).first()
@ -291,6 +309,17 @@ def delete_user_requests(window: timedelta = timedelta(days=1)):
return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
async def aget_user_name(user: KhojUser):
full_name = user.get_full_name()
if not is_none_or_empty(full_name):
return full_name
google_profile: GoogleUser = await GoogleUser.objects.filter(user=user).afirst()
if google_profile:
return google_profile.given_name
return None
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None
deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None

View file

@ -0,0 +1,4 @@
<?xml version="1.0" encoding="utf-8"?><!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
<svg width="800px" height="800px" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M3.95442 10.166C4.04608 9.76202 3.79293 9.36025 3.38898 9.26859C2.98504 9.17693 2.58327 9.43009 2.49161 9.83403L3.95442 10.166ZM5.49981 4.73283C5.19117 5.00907 5.1649 5.48322 5.44115 5.79187C5.71739 6.10051 6.19154 6.12678 6.50019 5.85053L5.49981 4.73283ZM15 14.25C14.5858 14.25 14.25 14.5858 14.25 15C14.25 15.4142 14.5858 15.75 15 15.75L15 14.25ZM17.25 18.7083C17.25 19.1225 17.5858 19.4583 18 19.4583C18.4142 19.4583 18.75 19.1225 18.75 18.7083H17.25ZM5.25 18.7083C5.25 19.1225 5.58579 19.4583 6 19.4583C6.41421 19.4583 6.75 19.1225 6.75 18.7083H5.25ZM9 15L8.99998 15.75H9V15ZM11 15.75C11.4142 15.75 11.75 15.4142 11.75 15C11.75 14.5858 11.4142 14.25 11 14.25V15.75ZM12 3.75C16.5563 3.75 20.25 7.44365 20.25 12H21.75C21.75 6.61522 17.3848 2.25 12 2.25V3.75ZM12 20.25C7.44365 20.25 3.75 16.5563 3.75 12H2.25C2.25 17.3848 6.61522 21.75 12 21.75V20.25ZM20.25 12C20.25 16.5563 16.5563 20.25 12 20.25V21.75C17.3848 21.75 21.75 17.3848 21.75 12H20.25ZM3.75 12C3.75 11.3688 3.82074 10.7551 3.95442 10.166L2.49161 9.83403C2.33338 10.5313 2.25 11.2564 2.25 12H3.75ZM6.50019 5.85053C7.96026 4.54373 9.88655 3.75 12 3.75V2.25C9.50333 2.25 7.22428 3.1894 5.49981 4.73283L6.50019 5.85053ZM14.25 9C14.25 10.2426 13.2426 11.25 12 11.25V12.75C14.0711 12.75 15.75 11.0711 15.75 9H14.25ZM12 11.25C10.7574 11.25 9.75 10.2426 9.75 9H8.25C8.25 11.0711 9.92893 12.75 12 12.75V11.25ZM9.75 9C9.75 7.75736 10.7574 6.75 12 6.75V5.25C9.92893 5.25 8.25 6.92893 8.25 9H9.75ZM12 6.75C13.2426 6.75 14.25 7.75736 14.25 9H15.75C15.75 6.92893 14.0711 5.25 12 5.25V6.75ZM15 15.75C15.6008 15.75 16.1482 16.0891 16.5769 16.6848C17.0089 17.2852 17.25 18.0598 17.25 18.7083H18.75C18.75 17.7371 18.4052 16.6575 17.7944 15.8086C17.1801 14.9551 16.2275 14.25 15 14.25L15 15.75ZM6.75 18.7083C6.75 18.0598 6.99109 17.2852 7.42315 16.6848C7.85183 16.0891 8.39919 15.75 8.99998 15.75L9.00002 14.25C7.77253 14.25 6.81989 14.9551 6.20564 15.8086C5.59477 16.6575 5.25 17.7371 5.25 18.7083H6.75ZM9 15.75H11V14.25H9V15.75Z" fill="#000000"/>
</svg>

View file

@ -103,7 +103,7 @@
.section-title {
margin: 0;
padding: 0 0 16px 0;
padding: 12px 0 16px 0;
font-size: 32;
font-weight: normal;
}
@ -361,6 +361,11 @@
margin-right: 8px;
}
input#profile_given_name {
width: 100%;
margin: 0;
}
@media screen and (max-width: 700px) {
.section-cards {
grid-template-columns: 1fr;

View file

@ -43,6 +43,22 @@ To get started, just start typing below. You can also type / to see a list of co
});
}
let region = null;
let city = null;
let countryName = null;
fetch("https://ipapi.co/json")
.then(response => response.json())
.then(data => {
region = data.region;
city = data.city;
countryName = data.country_name;
})
.catch(err => {
console.log(err);
return;
});
function formatDate(date) {
// Format date in HH:MM, DD MMM YYYY format
let time_string = date.toLocaleTimeString('en-IN', { hour: '2-digit', minute: '2-digit', hour12: false });
@ -345,7 +361,7 @@ To get started, just start typing below. You can also type / to see a list of co
}
// Generate backend API URL to execute query
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}`;
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}&region=${region}&city=${city}&country=${countryName}`;
let new_response = document.createElement("div");
new_response.classList.add("chat-message", "khoj");

View file

@ -3,6 +3,25 @@
<div class="page">
<div id="content" class="section">
<h2 class="section-title">Profile</h2>
<div class="section-cards">
<div class="card">
<div class="card-title-row">
<img class="card-icon" src="/static/assets/icons/user-silhouette.svg" alt="Profile Name">
<h3 class="card-title">
Name
</h3>
</div>
<div class="card-description-row">
<input type="text" id="profile_given_name" class="form-control" placeholder="Enter your name here" value="{{ given_name }}">
</div>
<div class="card-action-row">
<button id="save-model" class="card-button happy" onclick="saveProfileGivenName()">
Save
</button>
</div>
</div>
</div>
<h2 class="section-title">Content</h2>
<button id="compute-index-size" class="card-button" onclick="getIndexedDataSize()">
Data Usage
@ -297,6 +316,27 @@
</div>
<script>
function saveProfileGivenName() {
const givenName = document.getElementById("profile_given_name").value;
fetch('/api/config/user/name?name=' + givenName, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
}
})
.then(response => response.json())
.then(data => {
if (data.status == "ok") {
let notificationBanner = document.getElementById("notification-banner");
notificationBanner.innerHTML = "Profile name has been updated!";
notificationBanner.style.display = "block";
setTimeout(function() {
notificationBanner.style.display = "none";
}, 5000);
}
})
}
function updateChatModel() {
const chatModel = document.getElementById("chat-models").value;
const saveModelButton = document.getElementById("save-model");
@ -314,6 +354,14 @@
if (data.status == "ok") {
saveModelButton.innerHTML = "Save";
saveModelButton.disabled = false;
let notificationBanner = document.getElementById("notification-banner");
notificationBanner.innerHTML = "Conversation model has been updated!";
notificationBanner.style.display = "block";
setTimeout(function() {
notificationBanner.style.display = "none";
}, 5000);
} else {
saveModelButton.innerHTML = "Error";
saveModelButton.disabled = false;

View file

@ -13,6 +13,7 @@ from khoj.processor.conversation.utils import (
from khoj.utils import state
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
@ -24,6 +25,7 @@ def extract_questions_offline(
conversation_log={},
use_history: bool = True,
should_extract_questions: bool = True,
location_data: LocationData = None,
) -> List[str]:
"""
Infer search queries to retrieve relevant notes to answer user query
@ -45,6 +47,8 @@ def extract_questions_offline(
gpt4all_model = loaded_model or GPT4All(model)
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = ""
@ -68,6 +72,7 @@ def extract_questions_offline(
last_year=last_year,
last_christmas_date=last_christmas_date,
next_christmas_date=next_christmas_date,
location=location,
)
message = system_prompt + example_questions
state.chat_lock.acquire()
@ -133,6 +138,8 @@ def converse_offline(
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,
tokenizer_name=None,
location_data: LocationData = None,
user_name: str = None,
) -> Union[ThreadedGenerator, Iterator[str]]:
"""
Converse with user using Llama
@ -150,6 +157,15 @@ def converse_offline(
conversation_primer = prompts.query_prompt.format(query=user_query)
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
conversation_primer = f"{location_prompt}\n{conversation_primer}"
if user_name:
user_name_prompt = prompts.user_name.format(name=user_name)
conversation_primer = f"{user_name_prompt}\n{conversation_primer}"
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references_message):
return iter([prompts.no_notes_found.format()])

View file

@ -1,4 +1,3 @@
import json
import logging
from datetime import datetime, timedelta
from typing import Optional
@ -13,6 +12,7 @@ from khoj.processor.conversation.openai.utils import (
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
@ -24,6 +24,7 @@ def extract_questions(
api_key=None,
temperature=0,
max_tokens=100,
location_data: LocationData = None,
):
"""
Infer search queries to retrieve relevant notes to answer user query
@ -32,6 +33,8 @@ def extract_questions(
def _valid_question(question: str):
return not is_none_or_empty(question) and question != "[]"
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = "".join(
[
@ -56,6 +59,7 @@ def extract_questions(
chat_history=chat_history,
text=text,
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
)
messages = [ChatMessage(content=prompt, role="assistant")]
@ -125,6 +129,8 @@ def converse(
conversation_commands=[ConversationCommand.Default],
max_prompt_size=None,
tokenizer_name=None,
location_data: LocationData = None,
user_name: str = None,
):
"""
Converse with user using OpenAI's ChatGPT
@ -135,6 +141,15 @@ def converse(
conversation_primer = prompts.query_prompt.format(query=user_query)
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
conversation_primer = f"{location_prompt}\n{conversation_primer}"
if user_name:
user_name_prompt = prompts.user_name.format(name=user_name)
conversation_primer = f"{user_name_prompt}\n{conversation_primer}"
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
completion_func(chat_response=prompts.no_notes_found.format())

View file

@ -122,6 +122,9 @@ image_generation_improve_prompt = PromptTemplate.from_template(
"""
You are a talented creator. Generate a detailed prompt to generate an image based on the following description. Update the query below to improve the image generation. Add additional context to the query to improve the image generation. Make sure to retain any important information from the query. Use the conversation log to inform your response.
Today's Date: {current_date}
User's Location: {location}
Conversation Log:
{chat_history}
@ -181,7 +184,7 @@ Answer (in second person):"""
## --
extract_questions_gpt4all_sample = PromptTemplate.from_template(
"""
<s>[INST] <<SYS>>Current Date: {current_date}<</SYS>> [/INST]</s>
<s>[INST] <<SYS>>Current Date: {current_date}. User's Location: {location}<</SYS>> [/INST]</s>
<s>[INST] How was my trip to Cambodia? [/INST]
How was my trip to Cambodia?</s>
<s>[INST] Who did I visit the temple with on that trip? [/INST]
@ -215,6 +218,7 @@ You are Khoj, an extremely smart and helpful search assistant with the ability t
What searches, if any, will you need to perform to answer the users question?
Provide search queries as a JSON list of strings
Current Date: {current_date}
User's Location: {location}
Q: How was my trip to Cambodia?
@ -355,6 +359,7 @@ You are Khoj, an extremely smart and helpful search assistant. You are tasked wi
What Google searches, if any, will you need to perform to answer the user's question?
Provide search queries as a list of strings
Current Date: {current_date}
User's Location: {location}
Here are some examples:
History:
@ -431,3 +436,17 @@ You are using the **{model}** model on the **{device}**.
**version**: {version}
""".strip()
)
# Personalization to the user
# --
user_location = PromptTemplate.from_template(
"""
User's Location: {location}
""".strip()
)
user_name = PromptTemplate.from_template(
"""
User's Name: {name}
""".strip()
)
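
As a rough illustration of how these new templates feed the chat models: the converse functions format them and prepend the result to the conversation primer. The sample name, location, and query below are made up; the import path and query_prompt usage follow the conversation modules in this change.

from khoj.processor.conversation import prompts

# Sample values; real ones come from the client query parameters and the user profile
conversation_primer = prompts.query_prompt.format(query="Any good cafes near me?")
location_prompt = prompts.user_location.format(location="Bengaluru, Karnataka, India")
user_name_prompt = prompts.user_name.format(name="Saba")

# Mirrors the prepend order in converse()/converse_offline(): name, then location, then query
conversation_primer = f"{user_name_prompt}\n{location_prompt}\n{conversation_primer}"
print(conversation_primer)  # first lines: "User's Name: Saba", "User's Location: Bengaluru, Karnataka, India"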

View file

@ -7,6 +7,7 @@ import requests
from khoj.routers.helpers import extract_relevant_info, generate_online_subqueries
from khoj.utils.helpers import is_none_or_empty
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
@ -33,7 +34,7 @@ OLOSTEP_QUERY_PARAMS = {
}
async def search_with_google(query: str, conversation_history: dict):
async def search_with_google(query: str, conversation_history: dict, location: LocationData):
def _search_with_google(subquery: str):
payload = json.dumps(
{
@ -62,7 +63,7 @@ async def search_with_google(query: str, conversation_history: dict):
raise ValueError("SERPER_DEV_API_KEY is not set")
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(query, conversation_history)
subqueries = await generate_online_subqueries(query, conversation_history, location)
response_dict = {}

View file

@ -37,7 +37,7 @@ from khoj.search_type import image_search, text_search
from khoj.utils import constants, state
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, timer
from khoj.utils.rawconfig import SearchResponse
from khoj.utils.rawconfig import LocationData, SearchResponse
from khoj.utils.state import SearchType
# Initialize Router
@ -275,6 +275,7 @@ async def extract_references_and_questions(
n: int,
d: float,
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
location_data: LocationData = None,
):
user = request.user.object if request.user.is_authenticated else None
@ -320,7 +321,11 @@ async def extract_references_and_questions(
loaded_model = state.gpt4all_processor_config.loaded_model
inferred_queries = extract_questions_offline(
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
defiltered_query,
loaded_model=loaded_model,
conversation_log=meta_log,
should_extract_questions=False,
location_data=location_data,
)
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
@ -328,7 +333,11 @@ async def extract_references_and_questions(
api_key = openai_chat_config.api_key
chat_model = default_openai_llm.chat_model
inferred_queries = extract_questions(
defiltered_query, model=chat_model, api_key=api_key, conversation_log=meta_log
defiltered_query,
model=chat_model,
api_key=api_key,
conversation_log=meta_log,
location_data=location_data,
)
# Collate search results as context for GPT

View file

@ -10,7 +10,7 @@ from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_user_name
from khoj.database.models import KhojUser
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
@ -36,6 +36,7 @@ from khoj.utils.helpers import (
get_device,
is_none_or_empty,
)
from khoj.utils.rawconfig import LocationData
# Initialize Router
logger = logging.getLogger(__name__)
@ -219,6 +220,9 @@ async def chat(
stream: Optional[bool] = False,
slug: Optional[str] = None,
conversation_id: Optional[int] = None,
city: Optional[str] = None,
region: Optional[str] = None,
country: Optional[str] = None,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
),
@ -251,8 +255,15 @@ async def chat(
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
location = None
if city or region or country:
location = LocationData(city=city, region=region, country=country)
user_name = await aget_user_name(user)
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_commands
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_commands, location
)
online_results: Dict = dict()
@ -269,7 +280,7 @@ async def chat(
if ConversationCommand.Online in conversation_commands:
try:
online_results = await search_with_google(defiltered_query, meta_log)
online_results = await search_with_google(defiltered_query, meta_log, location)
except ValueError as e:
return StreamingResponse(
iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]),
@ -284,7 +295,7 @@ async def chat(
metadata={"conversation_command": conversation_commands[0].value},
**common.__dict__,
)
image, status_code, improved_image_prompt = await text_to_image(q, meta_log)
image, status_code, improved_image_prompt = await text_to_image(q, meta_log, location_data=location)
if image is None:
content_obj = {
"image": image,
@ -316,6 +327,8 @@ async def chat(
user,
request.user.client_app,
conversation_id,
location,
user_name,
)
chat_metadata.update({"conversation_command": ",".join([cmd.value for cmd in conversation_commands])})

View file

@ -290,6 +290,38 @@ async def get_indexed_data_size(request: Request, common: CommonQueryParams):
)
@api_config.post("/user/name", status_code=200)
@requires(["authenticated"])
def set_user_name(
request: Request,
name: str,
client: Optional[str] = None,
):
user = request.user.object
split_name = name.split(" ")
if len(split_name) > 2:
raise HTTPException(status_code=400, detail="Name must be in the format: Firstname Lastname")
if len(split_name) == 1:
first_name = split_name[0]
last_name = ""
else:
first_name, last_name = split_name[0], split_name[-1]
adapters.set_user_name(user, first_name, last_name)
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_user_name",
client=client,
)
return {"status": "ok"}
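
A usage sketch for the new endpoint, assuming a local Khoj server and the Bearer-token auth used by the other API clients; host and token are placeholders. Names with more than two words are rejected with a 400, and a single word sets only the first name.

import requests

KHOJ_URL = "http://localhost:42110"  # assumed local server address
TOKEN = "<your-khoj-api-token>"      # placeholder

# The name is passed as a query parameter, mirroring the settings page's fetch() call
response = requests.post(
    f"{KHOJ_URL}/api/config/user/name",
    params={"name": "Saba Imran"},
    headers={"Authorization": f"Bearer {TOKEN}"},
)
print(response.json())  # {"status": "ok"} on success
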
@api_config.get("/types", response_model=List[str])
@requires(["authenticated"])
def get_config_types(

View file

@ -1,7 +1,6 @@
import asyncio
import json
import logging
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
from functools import partial
@ -9,6 +8,7 @@ from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
import openai
import requests
from fastapi import Depends, Header, HTTPException, Request, UploadFile
from starlette.authentication import has_required_scope
@ -39,6 +39,7 @@ from khoj.utils.helpers import (
log_telemetry,
tool_descriptions_for_llm,
)
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
@ -180,10 +181,11 @@ async def aget_relevant_information_sources(query: str, conversation_history: di
return [ConversationCommand.Default]
async def generate_online_subqueries(q: str, conversation_history: dict) -> List[str]:
async def generate_online_subqueries(q: str, conversation_history: dict, location_data: LocationData) -> List[str]:
"""
Generate subqueries from the given query
"""
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
chat_history = construct_chat_history(conversation_history)
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
@ -191,6 +193,7 @@ async def generate_online_subqueries(q: str, conversation_history: dict) -> List
current_date=utc_date,
query=q,
chat_history=chat_history,
location=location,
)
response = await send_message_to_model_wrapper(online_queries_prompt)
@ -227,14 +230,19 @@ async def extract_relevant_info(q: str, corpus: dict) -> List[str]:
return response.strip()
async def generate_better_image_prompt(q: str, conversation_history: str) -> str:
async def generate_better_image_prompt(q: str, conversation_history: str, location_data: LocationData) -> str:
"""
Generate a better image prompt from the given query
"""
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d")
image_prompt = prompts.image_generation_improve_prompt.format(
query=q,
chat_history=conversation_history,
location=location,
current_date=today_date,
)
response = await send_message_to_model_wrapper(image_prompt)
@ -293,6 +301,8 @@ def generate_chat_response(
user: KhojUser = None,
client_application: ClientApplication = None,
conversation_id: int = None,
location_data: LocationData = None,
user_name: Optional[str] = None,
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
chat_response = None
@ -330,6 +340,8 @@ def generate_chat_response(
model=conversation_config.chat_model,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
location_data=location_data,
user_name=user_name,
)
elif conversation_config.model_type == "openai":
@ -347,6 +359,8 @@ def generate_chat_response(
conversation_commands=conversation_commands,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
location_data=location_data,
user_name=user_name,
)
metadata.update({"chat_model": conversation_config.chat_model})
@ -358,7 +372,9 @@ def generate_chat_response(
return chat_response, metadata
async def text_to_image(message: str, conversation_log: dict) -> Tuple[Optional[str], int, Optional[str]]:
async def text_to_image(
message: str, conversation_log: dict, location_data: LocationData
) -> Tuple[Optional[str], int, Optional[str]]:
status_code = 200
image = None
@ -373,7 +389,7 @@ async def text_to_image(message: str, conversation_log: dict) -> Tuple[Optional[
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"
improved_image_prompt = await generate_better_image_prompt(message, chat_history)
improved_image_prompt = await generate_better_image_prompt(message, chat_history, location_data=location_data)
try:
response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"

View file

@ -13,6 +13,7 @@ from khoj.database.adapters import (
ConversationAdapters,
EntryAdapters,
get_user_github_config,
get_user_name,
get_user_notion_config,
get_user_subscription_state,
)
@ -138,6 +139,7 @@ def config_page(request: Request):
if user_subscription and user_subscription.renewal_date
else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y")
)
given_name = get_user_name(user)
enabled_content_source = set(EntryAdapters.get_unique_file_sources(user))
successfully_configured = {
@ -166,6 +168,7 @@ def config_page(request: Request):
"current_model_state": successfully_configured,
"anonymous_mode": state.anonymous_mode,
"username": user.username,
"given_name": given_name,
"conversation_options": all_conversation_options,
"search_model_options": all_search_model_options,
"selected_search_model_config": current_search_model_option.id,

View file

@ -283,8 +283,8 @@ command_descriptions = {
}
tool_descriptions_for_llm = {
ConversationCommand.Default: "Use this if there might be a mix of general and personal knowledge in the question",
ConversationCommand.General: "Use this when you can answer the question without needing any additional online or personal information",
ConversationCommand.Default: "Use this if there might be a mix of general and personal knowledge in the question, or if you can't make sense of the query",
ConversationCommand.General: "Use this when you can answer the question without any outside information or personal knowledge",
ConversationCommand.Notes: "Use this when you would like to use the user's personal knowledge base to answer the question",
ConversationCommand.Online: "Use this when you would like to look up information on the internet",
}

View file

@ -21,6 +21,12 @@ class ConfigBase(BaseModel):
return setattr(self, key, value)
class LocationData(BaseModel):
city: Optional[str]
region: Optional[str]
country: Optional[str]
class TextConfigBase(ConfigBase):
compressed_jsonl: Path
embeddings_file: Path
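
For illustration, a minimal sketch of the new LocationData model in use; the sample values are arbitrary. The /api/chat endpoint builds the object from the region/city/country query parameters, and the conversation helpers flatten it into a single string for the prompts.

from khoj.utils.rawconfig import LocationData

location = LocationData(city="Bengaluru", region="Karnataka", country="India")
# Same "city, region, country" string the conversation modules construct
print(f"{location.city}, {location.region}, {location.country}")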