mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Merge branch 'master' into features/advanced-reasoning
This commit is contained in:
commit
feb6d65ef8
9 changed files with 693 additions and 515 deletions
File diff suppressed because it is too large
Load diff
|
@ -696,10 +696,12 @@ class AgentAdapters:
|
|||
files: List[str],
|
||||
input_tools: List[str],
|
||||
output_modes: List[str],
|
||||
slug: Optional[str] = None,
|
||||
):
|
||||
chat_model_option = await ChatModelOptions.objects.filter(chat_model=chat_model).afirst()
|
||||
|
||||
agent, created = await Agent.objects.filter(name=name, creator=user).aupdate_or_create(
|
||||
# Slug will be None for new agents, which will trigger a new agent creation with a generated, immutable slug
|
||||
agent, created = await Agent.objects.filter(slug=slug, creator=user).aupdate_or_create(
|
||||
defaults={
|
||||
"name": name,
|
||||
"creator": user,
|
||||
|
|
|
@ -114,6 +114,7 @@ class CrossEncoderModel:
|
|||
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
response = requests.post(target_url, json=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()["scores"]
|
||||
|
||||
cross_inp = [[query, hit.additional[key]] for hit in hits]
|
||||
|
|
|
@ -143,7 +143,6 @@ async def read_webpages(
|
|||
conversation_history: dict,
|
||||
location: LocationData,
|
||||
user: KhojUser,
|
||||
subscribed: bool = False,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
uploaded_image_url: str = None,
|
||||
agent: Agent = None,
|
||||
|
|
|
@ -35,6 +35,7 @@ class ModifyAgentBody(BaseModel):
|
|||
files: Optional[List[str]] = []
|
||||
input_tools: Optional[List[str]] = []
|
||||
output_modes: Optional[List[str]] = []
|
||||
slug: Optional[str] = None
|
||||
|
||||
|
||||
@api_agents.get("", response_class=Response)
|
||||
|
@ -192,6 +193,7 @@ async def create_agent(
|
|||
body.files,
|
||||
body.input_tools,
|
||||
body.output_modes,
|
||||
body.slug,
|
||||
)
|
||||
|
||||
agents_packet = {
|
||||
|
@ -233,7 +235,7 @@ async def update_agent(
|
|||
status_code=400,
|
||||
)
|
||||
|
||||
selected_agent = await AgentAdapters.aget_agent_by_name(body.name, user)
|
||||
selected_agent = await AgentAdapters.aget_agent_by_slug(body.slug, user)
|
||||
|
||||
if not selected_agent:
|
||||
return Response(
|
||||
|
@ -253,6 +255,7 @@ async def update_agent(
|
|||
body.files,
|
||||
body.input_tools,
|
||||
body.output_modes,
|
||||
body.slug,
|
||||
)
|
||||
|
||||
agents_packet = {
|
||||
|
|
|
@ -213,7 +213,7 @@ def chat_history(
|
|||
|
||||
agent_metadata = None
|
||||
if conversation.agent:
|
||||
if conversation.agent.privacy_level == Agent.PrivacyLevel.PRIVATE:
|
||||
if conversation.agent.privacy_level == Agent.PrivacyLevel.PRIVATE and conversation.agent.creator != user:
|
||||
conversation.agent = None
|
||||
else:
|
||||
agent_metadata = {
|
||||
|
@ -853,27 +853,36 @@ async def chat(
|
|||
return
|
||||
|
||||
# # Gather Context
|
||||
# async for result in extract_references_and_questions(
|
||||
# request,
|
||||
# meta_log,
|
||||
# q,
|
||||
# (n or 7),
|
||||
# d,
|
||||
# conversation_id,
|
||||
# conversation_commands,
|
||||
# location,
|
||||
# partial(send_event, ChatEvent.STATUS),
|
||||
# uploaded_image_url=uploaded_image_url,
|
||||
# agent=agent,
|
||||
# ):
|
||||
# if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
# yield result[ChatEvent.STATUS]
|
||||
# else:
|
||||
# compiled_references.extend(result[0])
|
||||
# inferred_queries.extend(result[1])
|
||||
# defiltered_query = result[2]
|
||||
|
||||
# if not is_none_or_empty(compiled_references):
|
||||
# # Extract Document References
|
||||
# try:
|
||||
# async for result in extract_references_and_questions(
|
||||
# request,
|
||||
# meta_log,
|
||||
# q,
|
||||
# (n or 7),
|
||||
# d,
|
||||
# conversation_id,
|
||||
# conversation_commands,
|
||||
# location,
|
||||
# partial(send_event, ChatEvent.STATUS),
|
||||
# uploaded_image_url=uploaded_image_url,
|
||||
# agent=agent,
|
||||
# ):
|
||||
# if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
# yield result[ChatEvent.STATUS]
|
||||
# else:
|
||||
# compiled_references.extend(result[0])
|
||||
# inferred_queries.extend(result[1])
|
||||
# defiltered_query = result[2]
|
||||
# except Exception as e:
|
||||
# error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
|
||||
# logger.warning(error_message)
|
||||
# async for result in send_event(
|
||||
# ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
|
||||
# ):
|
||||
# yield result
|
||||
#
|
||||
# # if not is_none_or_empty(compiled_references):
|
||||
# try:
|
||||
# headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
||||
# # Strip only leading # from headings
|
||||
|
@ -910,12 +919,13 @@ async def chat(
|
|||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
online_results = result
|
||||
except ValueError as e:
|
||||
except Exception as e:
|
||||
error_message = f"Error searching online: {e}. Attempting to respond without online results"
|
||||
logger.warning(error_message)
|
||||
async for result in send_llm_response(error_message):
|
||||
async for result in send_event(
|
||||
ChatEvent.STATUS, "Online search failed. I'll try respond without online references"
|
||||
):
|
||||
yield result
|
||||
return
|
||||
|
||||
## Gather Webpage References
|
||||
if ConversationCommand.Webpage in conversation_commands and pending_research:
|
||||
|
@ -925,7 +935,6 @@ async def chat(
|
|||
meta_log,
|
||||
location,
|
||||
user,
|
||||
subscribed,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
agent=agent,
|
||||
|
@ -945,11 +954,15 @@ async def chat(
|
|||
webpages.append(webpage["link"])
|
||||
async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
|
||||
yield result
|
||||
except ValueError as e:
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error directly reading webpages: {e}. Attempting to respond without online results",
|
||||
f"Error reading webpages: {e}. Attempting to respond without webpage results",
|
||||
exc_info=True,
|
||||
)
|
||||
async for result in send_event(
|
||||
ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references"
|
||||
):
|
||||
yield result
|
||||
|
||||
## Gather Code Results
|
||||
if ConversationCommand.Code in conversation_commands and pending_research:
|
||||
|
|
|
@ -345,13 +345,13 @@ async def aget_relevant_information_sources(
|
|||
final_response = [ConversationCommand.Default]
|
||||
else:
|
||||
final_response = [ConversationCommand.General]
|
||||
return final_response
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.error(f"Invalid response for determining relevant tools: {response}")
|
||||
if len(agent_tools) == 0:
|
||||
final_response = [ConversationCommand.Default]
|
||||
else:
|
||||
final_response = agent_tools
|
||||
return final_response
|
||||
|
||||
|
||||
async def aget_relevant_output_modes(
|
||||
|
|
|
@ -227,7 +227,6 @@ async def execute_information_collection(
|
|||
conversation_history,
|
||||
location,
|
||||
user,
|
||||
subscribed,
|
||||
send_status_func,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
agent=agent,
|
||||
|
|
|
@ -3,6 +3,7 @@ import math
|
|||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Type, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from asgiref.sync import sync_to_async
|
||||
from sentence_transformers import util
|
||||
|
@ -231,8 +232,12 @@ def setup(
|
|||
|
||||
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
|
||||
"""Score all retrieved entries using the cross-encoder"""
|
||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
|
||||
try:
|
||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logger.error(f"Failed to rerank documents using the inference endpoint. Error: {e}.", exc_info=True)
|
||||
cross_scores = [0.0] * len(hits)
|
||||
|
||||
# Convert cross-encoder scores to distances and pass in hits for reranking
|
||||
for idx in range(len(cross_scores)):
|
||||
|
|
Loading…
Reference in a new issue