mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Try respond even if document search via inference endpoint fails
The huggingface endpoint can be flaky. Khoj shouldn't refuse to respond to user if document search fails. It should transparently mention that document lookup failed. But try respond as best as it can without the document references This changes provides graceful failover when inference endpoint requests fail either when encoding query or reranking retrieved docs
This commit is contained in:
parent
9affeb9e85
commit
1b04b801c6
3 changed files with 35 additions and 22 deletions
|
@ -114,6 +114,7 @@ class CrossEncoderModel:
|
|||
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
response = requests.post(target_url, json=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()["scores"]
|
||||
|
||||
cross_inp = [[query, hit.additional[key]] for hit in hits]
|
||||
|
|
|
@ -3,7 +3,6 @@ import base64
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Dict, Optional
|
||||
|
@ -839,25 +838,33 @@ async def chat(
|
|||
# Gather Context
|
||||
## Extract Document References
|
||||
compiled_references, inferred_queries, defiltered_query = [], [], None
|
||||
async for result in extract_references_and_questions(
|
||||
request,
|
||||
meta_log,
|
||||
q,
|
||||
(n or 7),
|
||||
d,
|
||||
conversation_id,
|
||||
conversation_commands,
|
||||
location,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
agent=agent,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
compiled_references.extend(result[0])
|
||||
inferred_queries.extend(result[1])
|
||||
defiltered_query = result[2]
|
||||
try:
|
||||
async for result in extract_references_and_questions(
|
||||
request,
|
||||
meta_log,
|
||||
q,
|
||||
(n or 7),
|
||||
d,
|
||||
conversation_id,
|
||||
conversation_commands,
|
||||
location,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
agent=agent,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
compiled_references.extend(result[0])
|
||||
inferred_queries.extend(result[1])
|
||||
defiltered_query = result[2]
|
||||
except Exception as e:
|
||||
error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
|
||||
logger.warning(error_message)
|
||||
async for result in send_event(
|
||||
ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
|
||||
):
|
||||
yield result
|
||||
|
||||
if not is_none_or_empty(compiled_references):
|
||||
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
||||
|
|
|
@ -3,6 +3,7 @@ import math
|
|||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Type, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from asgiref.sync import sync_to_async
|
||||
from sentence_transformers import util
|
||||
|
@ -231,8 +232,12 @@ def setup(
|
|||
|
||||
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
|
||||
"""Score all retrieved entries using the cross-encoder"""
|
||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
|
||||
try:
|
||||
with timer("Cross-Encoder Predict Time", logger, state.device):
|
||||
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logger.error(f"Failed to rerank documents using the inference endpoint. Error: {e}.", exc_info=True)
|
||||
cross_scores = [0.0] * len(hits)
|
||||
|
||||
# Convert cross-encoder scores to distances and pass in hits for reranking
|
||||
for idx in range(len(cross_scores)):
|
||||
|
|
Loading…
Reference in a new issue