mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Make Online Search Location Aware (#929)
## Overview Add user country code as context for doing online search with serper.dev API. This should find more user relevant results from online searches by Khoj ## Details ### Major - Default to using system clock to infer user timezone on js clients - Infer country from timezone when only timezone received by chat API - Localize online search results to user country when location available ### Minor - Add `__str__` func to `LocationData` class to deduplicate location string generation
This commit is contained in:
commit
4a1cb50da3
16 changed files with 95 additions and 53 deletions
|
@ -60,7 +60,8 @@
|
|||
let region = null;
|
||||
let city = null;
|
||||
let countryName = null;
|
||||
let timezone = null;
|
||||
let countryCode = null;
|
||||
let timezone = Intl.DateTimeFormat().resolvedOptions().timeZone;
|
||||
let chatMessageState = {
|
||||
newResponseTextEl: null,
|
||||
newResponseEl: null,
|
||||
|
@ -76,6 +77,7 @@
|
|||
region = data.region;
|
||||
city = data.city;
|
||||
countryName = data.country_name;
|
||||
countryCode = data.country_code;
|
||||
timezone = data.timezone;
|
||||
})
|
||||
.catch(err => {
|
||||
|
@ -157,6 +159,7 @@
|
|||
...(!!city && { city: city }),
|
||||
...(!!region && { region: region }),
|
||||
...(!!countryName && { country: countryName }),
|
||||
...(!!countryCode && { country_code: countryCode }),
|
||||
...(!!timezone && { timezone: timezone }),
|
||||
};
|
||||
|
||||
|
|
|
@ -308,18 +308,19 @@
|
|||
<script src="./utils.js"></script>
|
||||
<script src="./chatutils.js"></script>
|
||||
<script>
|
||||
|
||||
let region = null;
|
||||
let city = null;
|
||||
let countryName = null;
|
||||
let timezone = null;
|
||||
let countryCode = null;
|
||||
let timezone = Intl.DateTimeFormat().resolvedOptions().timeZone;
|
||||
|
||||
fetch("https://ipapi.co/json")
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
region = data.region;
|
||||
city = data.city;
|
||||
region = data.region;
|
||||
countryName = data.country_name;
|
||||
countryCode = data.country_code;
|
||||
timezone = data.timezone;
|
||||
})
|
||||
.catch(err => {
|
||||
|
@ -410,6 +411,7 @@
|
|||
...(!!city && { city: city }),
|
||||
...(!!region && { region: region }),
|
||||
...(!!countryName && { country: countryName }),
|
||||
...(!!countryCode && { country_code: countryCode }),
|
||||
...(!!timezone && { timezone: timezone }),
|
||||
};
|
||||
|
||||
|
|
|
@ -33,9 +33,10 @@ interface ChatMessageState {
|
|||
}
|
||||
|
||||
interface Location {
|
||||
region: string;
|
||||
city: string;
|
||||
countryName: string;
|
||||
region?: string;
|
||||
city?: string;
|
||||
countryName?: string;
|
||||
countryCode?: string;
|
||||
timezone: string;
|
||||
}
|
||||
|
||||
|
@ -43,7 +44,7 @@ export class KhojChatView extends KhojPaneView {
|
|||
result: string;
|
||||
setting: KhojSetting;
|
||||
waitingForLocation: boolean;
|
||||
location: Location;
|
||||
location: Location = { timezone: Intl.DateTimeFormat().resolvedOptions().timeZone };
|
||||
keyPressTimeout: NodeJS.Timeout | null = null;
|
||||
userMessages: string[] = []; // Store user sent messages for input history cycling
|
||||
currentMessageIndex: number = -1; // Track current message index in userMessages array
|
||||
|
@ -70,6 +71,7 @@ export class KhojChatView extends KhojPaneView {
|
|||
region: data.region,
|
||||
city: data.city,
|
||||
countryName: data.country_name,
|
||||
countryCode: data.country_code,
|
||||
timezone: data.timezone,
|
||||
};
|
||||
})
|
||||
|
@ -1056,12 +1058,11 @@ export class KhojChatView extends KhojPaneView {
|
|||
n: this.setting.resultsCount,
|
||||
stream: true,
|
||||
...(!!conversationId && { conversation_id: conversationId }),
|
||||
...(!!this.location && {
|
||||
city: this.location.city,
|
||||
region: this.location.region,
|
||||
country: this.location.countryName,
|
||||
timezone: this.location.timezone,
|
||||
}),
|
||||
...(!!this.location && this.location.city && { city: this.location.city }),
|
||||
...(!!this.location && this.location.region && { region: this.location.region }),
|
||||
...(!!this.location && this.location.countryName && { country: this.location.countryName }),
|
||||
...(!!this.location && this.location.countryCode && { country_code: this.location.countryCode }),
|
||||
...(!!this.location && this.location.timezone && { timezone: this.location.timezone }),
|
||||
};
|
||||
|
||||
let newResponseEl = this.createKhojResponseDiv();
|
||||
|
|
|
@ -518,12 +518,14 @@ function EditCard(props: EditCardProps) {
|
|||
updateQueryUrl += `&subject=${encodeURIComponent(values.subject)}`;
|
||||
}
|
||||
updateQueryUrl += `&crontime=${encodeURIComponent(cronFrequency)}`;
|
||||
if (props.locationData) {
|
||||
if (props.locationData && props.locationData.city)
|
||||
updateQueryUrl += `&city=${encodeURIComponent(props.locationData.city)}`;
|
||||
if (props.locationData && props.locationData.region)
|
||||
updateQueryUrl += `®ion=${encodeURIComponent(props.locationData.region)}`;
|
||||
if (props.locationData && props.locationData.country)
|
||||
updateQueryUrl += `&country=${encodeURIComponent(props.locationData.country)}`;
|
||||
if (props.locationData && props.locationData.timezone)
|
||||
updateQueryUrl += `&timezone=${encodeURIComponent(props.locationData.timezone)}`;
|
||||
}
|
||||
|
||||
let method = props.createNew ? "POST" : "PUT";
|
||||
|
||||
|
|
|
@ -136,7 +136,9 @@ export default function Chat() {
|
|||
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
|
||||
const [image64, setImage64] = useState<string>("");
|
||||
|
||||
const locationData = useIPLocationData();
|
||||
const locationData = useIPLocationData() || {
|
||||
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
|
||||
};
|
||||
const authenticatedData = useAuthenticatedData();
|
||||
const isMobileWidth = useIsMobileWidth();
|
||||
|
||||
|
@ -241,9 +243,10 @@ export default function Chat() {
|
|||
conversation_id: conversationId,
|
||||
stream: true,
|
||||
...(locationData && {
|
||||
city: locationData.city,
|
||||
region: locationData.region,
|
||||
country: locationData.country,
|
||||
city: locationData.city,
|
||||
country_code: locationData.countryCode,
|
||||
timezone: locationData.timezone,
|
||||
}),
|
||||
...(image64 && { image: image64 }),
|
||||
|
|
|
@ -2,13 +2,10 @@ import { useEffect, useState } from "react";
|
|||
import useSWR from "swr";
|
||||
|
||||
export interface LocationData {
|
||||
ip: string;
|
||||
city: string;
|
||||
region: string;
|
||||
country: string;
|
||||
postal: string;
|
||||
latitude: number;
|
||||
longitude: number;
|
||||
city?: string;
|
||||
region?: string;
|
||||
country?: string;
|
||||
countryCode?: string;
|
||||
timezone: string;
|
||||
}
|
||||
|
||||
|
@ -50,9 +47,7 @@ export function useIPLocationData() {
|
|||
{ revalidateOnFocus: false },
|
||||
);
|
||||
|
||||
if (locationDataError) return null;
|
||||
if (!locationData) return null;
|
||||
|
||||
if (locationDataError || !locationData) return;
|
||||
return locationData;
|
||||
}
|
||||
|
||||
|
|
|
@ -111,7 +111,9 @@ export default function SharedChat() {
|
|||
const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
|
||||
const [image64, setImage64] = useState<string>("");
|
||||
|
||||
const locationData = useIPLocationData();
|
||||
const locationData = useIPLocationData() || {
|
||||
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
|
||||
};
|
||||
const authenticatedData = useAuthenticatedData();
|
||||
const isMobileWidth = useIsMobileWidth();
|
||||
|
||||
|
@ -231,6 +233,7 @@ export default function SharedChat() {
|
|||
region: locationData.region,
|
||||
country: locationData.country,
|
||||
city: locationData.city,
|
||||
country_code: locationData.countryCode,
|
||||
timezone: locationData.timezone,
|
||||
}),
|
||||
...(image64 && { image: image64 }),
|
||||
|
|
|
@ -32,7 +32,7 @@ def extract_questions_anthropic(
|
|||
Infer search queries to retrieve relevant notes to answer user query
|
||||
"""
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
|
@ -158,8 +158,7 @@ def converse_anthropic(
|
|||
)
|
||||
|
||||
if location_data:
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
||||
location_prompt = prompts.user_location.format(location=location)
|
||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||
|
||||
if user_name:
|
||||
|
|
|
@ -33,7 +33,7 @@ def extract_questions_gemini(
|
|||
Infer search queries to retrieve relevant notes to answer user query
|
||||
"""
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
|
@ -163,8 +163,7 @@ def converse_gemini(
|
|||
)
|
||||
|
||||
if location_data:
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
||||
location_prompt = prompts.user_location.format(location=location)
|
||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||
|
||||
if user_name:
|
||||
|
|
|
@ -46,7 +46,7 @@ def extract_questions_offline(
|
|||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
|
@ -171,8 +171,7 @@ def converse_offline(
|
|||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
||||
|
||||
if location_data:
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
||||
location_prompt = prompts.user_location.format(location=location)
|
||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||
|
||||
if user_name:
|
||||
|
|
|
@ -36,7 +36,7 @@ def extract_questions(
|
|||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
"""
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
|
@ -159,8 +159,7 @@ def converse(
|
|||
)
|
||||
|
||||
if location_data:
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
||||
location_prompt = prompts.user_location.format(location=location)
|
||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||
system_prompt = f"{system_prompt}\n{location_prompt}"
|
||||
|
||||
if user_name:
|
||||
|
|
|
@ -7,7 +7,6 @@ from collections import defaultdict
|
|||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from markdownify import markdownify
|
||||
|
||||
|
@ -80,7 +79,7 @@ async def search_online(
|
|||
|
||||
with timer(f"Internet searches for {list(subqueries)} took", logger):
|
||||
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
|
||||
search_tasks = [search_func(subquery) for subquery in subqueries]
|
||||
search_tasks = [search_func(subquery, location) for subquery in subqueries]
|
||||
search_results = await asyncio.gather(*search_tasks)
|
||||
response_dict = {subquery: search_result for subquery, search_result in search_results}
|
||||
|
||||
|
@ -115,8 +114,9 @@ async def search_online(
|
|||
yield response_dict
|
||||
|
||||
|
||||
async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
|
||||
payload = json.dumps({"q": query})
|
||||
async def search_with_google(query: str, location: LocationData) -> Tuple[str, Dict[str, List[Dict]]]:
|
||||
country_code = location.country_code.lower() if location and location.country_code else "us"
|
||||
payload = json.dumps({"q": query, "gl": country_code})
|
||||
headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
|
@ -220,7 +220,7 @@ async def read_webpage_with_jina(web_url: str) -> str:
|
|||
return response_json["data"]["content"]
|
||||
|
||||
|
||||
async def search_with_jina(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
|
||||
async def search_with_jina(query: str, location: LocationData) -> Tuple[str, Dict[str, List[Dict]]]:
|
||||
encoded_query = urllib.parse.quote(query)
|
||||
jina_search_api_url = f"{JINA_SEARCH_API_URL}/{encoded_query}"
|
||||
headers = {"Accept": "application/json"}
|
||||
|
|
|
@ -55,6 +55,8 @@ from khoj.utils.helpers import (
|
|||
ConversationCommand,
|
||||
command_descriptions,
|
||||
convert_image_to_webp,
|
||||
get_country_code_from_timezone,
|
||||
get_country_name_from_timezone,
|
||||
get_device,
|
||||
is_none_or_empty,
|
||||
)
|
||||
|
@ -529,6 +531,7 @@ class ChatRequestBody(BaseModel):
|
|||
city: Optional[str] = None
|
||||
region: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
country_code: Optional[str] = None
|
||||
timezone: Optional[str] = None
|
||||
image: Optional[str] = None
|
||||
create_new: Optional[bool] = False
|
||||
|
@ -556,7 +559,8 @@ async def chat(
|
|||
conversation_id = body.conversation_id
|
||||
city = body.city
|
||||
region = body.region
|
||||
country = body.country
|
||||
country = body.country or get_country_name_from_timezone(body.timezone)
|
||||
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
|
||||
timezone = body.timezone
|
||||
image = body.image
|
||||
|
||||
|
@ -658,8 +662,8 @@ async def chat(
|
|||
|
||||
user_name = await aget_user_name(user)
|
||||
location = None
|
||||
if city or region or country:
|
||||
location = LocationData(city=city, region=region, country=country)
|
||||
if city or region or country or country_code:
|
||||
location = LocationData(city=city, region=region, country=country, country_code=country_code)
|
||||
|
||||
if is_query_empty(q):
|
||||
async for result in send_llm_response("Please ask your query to get started."):
|
||||
|
|
|
@ -369,7 +369,7 @@ async def infer_webpage_urls(
|
|||
"""
|
||||
Infer webpage links from the given query
|
||||
"""
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
|
@ -405,7 +405,7 @@ async def generate_online_subqueries(
|
|||
"""
|
||||
Generate subqueries from the given query
|
||||
"""
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||
location = f"{location_data}" if location_data else "Unknown"
|
||||
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
|
@ -535,8 +535,7 @@ async def generate_better_image_prompt(
|
|||
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
|
||||
|
||||
if location_data:
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
||||
location_prompt = prompts.user_location.format(location=location)
|
||||
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||
else:
|
||||
location_prompt = "Unknown"
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ import random
|
|||
import uuid
|
||||
from collections import OrderedDict
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from importlib import import_module
|
||||
from importlib.metadata import version
|
||||
from itertools import islice
|
||||
|
@ -24,6 +25,7 @@ import torch
|
|||
from asgiref.sync import sync_to_async
|
||||
from magika import Magika
|
||||
from PIL import Image
|
||||
from pytz import country_names, country_timezones
|
||||
|
||||
from khoj.utils import constants
|
||||
|
||||
|
@ -431,3 +433,24 @@ def convert_image_to_webp(image_bytes):
|
|||
webp_image_bytes = webp_image_io.getvalue()
|
||||
webp_image_io.close()
|
||||
return webp_image_bytes
|
||||
|
||||
|
||||
@lru_cache
|
||||
def tz_to_cc_map() -> dict[str, str]:
|
||||
"""Create a mapping of timezone to country code"""
|
||||
timezone_country = {}
|
||||
for countrycode in country_timezones:
|
||||
timezones = country_timezones[countrycode]
|
||||
for timezone in timezones:
|
||||
timezone_country[timezone] = countrycode
|
||||
return timezone_country
|
||||
|
||||
|
||||
def get_country_code_from_timezone(tz: str) -> str:
|
||||
"""Get country code from timezone"""
|
||||
return tz_to_cc_map().get(tz, "US")
|
||||
|
||||
|
||||
def get_country_name_from_timezone(tz: str) -> str:
|
||||
"""Get country name from timezone"""
|
||||
return country_names.get(get_country_code_from_timezone(tz), "United States")
|
||||
|
|
|
@ -25,6 +25,17 @@ class LocationData(BaseModel):
|
|||
city: Optional[str]
|
||||
region: Optional[str]
|
||||
country: Optional[str]
|
||||
country_code: Optional[str]
|
||||
|
||||
def __str__(self):
|
||||
parts = []
|
||||
if self.city:
|
||||
parts.append(self.city)
|
||||
if self.region:
|
||||
parts.append(self.region)
|
||||
if self.country:
|
||||
parts.append(self.country)
|
||||
return ", ".join(parts)
|
||||
|
||||
|
||||
class FileFilterRequest(BaseModel):
|
||||
|
|
Loading…
Reference in a new issue