Wrap common API query parameters into shared class to deduplicate code

- Upgrade FastAPI to >= latest version. Required upgrade of FastAPI.
  Earlier version didn't support wrapping common query params in class

- Use per fixture app instead of a global FastAPI app in conftest

- Upgrade minimum required Django version

- Fix no notes chat director test with updated no notes message
  No notes message was updated in commit 118f1143
This commit is contained in:
Debanjum Singh Solanky 2023-11-17 18:22:45 -08:00
parent 68ac1e0193
commit ca87b4ede9
5 changed files with 38 additions and 50 deletions

View file

@ -39,7 +39,7 @@ dependencies = [
"bs4 >= 0.0.1",
"dateparser >= 1.1.1",
"defusedxml == 0.7.1",
"fastapi == 0.77.1",
"fastapi >= 0.104.1",
"python-multipart >= 0.0.5",
"jinja2 == 3.1.2",
"openai >= 0.27.0, < 1.0.0",
@ -60,7 +60,7 @@ dependencies = [
"bs4 >= 0.0.1",
"anyio == 3.7.1",
"pymupdf >= 1.23.5",
"django == 4.2.5",
"django == 4.2.7",
"authlib == 1.2.1",
"gpt4all >= 2.0.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
"gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'",

View file

@ -4,7 +4,7 @@ import math
import time
import logging
import json
from typing import List, Optional, Union, Any
from typing import Annotated, List, Optional, Union, Any
# External Packages
from fastapi import APIRouter, Depends, HTTPException, Header, Request
@ -31,6 +31,7 @@ from khoj.utils import state, constants
from khoj.utils.helpers import AsyncIteratorWrapper, get_device
from fastapi.responses import StreamingResponse, Response
from khoj.routers.helpers import (
CommonQueryParams,
get_conversation_command,
validate_conversation_config,
agenerate_chat_response,
@ -354,15 +355,12 @@ def get_config_types(
async def search(
q: str,
request: Request,
common: CommonQueryParams,
n: Optional[int] = 5,
t: Optional[SearchType] = SearchType.All,
r: Optional[bool] = False,
max_distance: Optional[Union[float, None]] = None,
dedupe: Optional[bool] = True,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
user = request.user.object
start_time = time.time()
@ -466,10 +464,7 @@ async def search(
request=request,
telemetry_type="api",
api="search",
client=client,
user_agent=user_agent,
referer=referer,
host=host,
**common.__dict__,
)
end_time = time.time()
@ -482,12 +477,9 @@ async def search(
@requires(["authenticated"])
def update(
request: Request,
common: CommonQueryParams,
t: Optional[SearchType] = None,
force: Optional[bool] = False,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
user = request.user.object
if not state.config:
@ -513,10 +505,7 @@ def update(
request=request,
telemetry_type="api",
api="update",
client=client,
user_agent=user_agent,
referer=referer,
host=host,
**common.__dict__,
)
return {"status": "ok", "message": "khoj reloaded"}
@ -526,10 +515,7 @@ def update(
@requires(["authenticated"])
def chat_history(
request: Request,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
common: CommonQueryParams,
):
user = request.user.object
validate_conversation_config()
@ -541,10 +527,7 @@ def chat_history(
request=request,
telemetry_type="api",
api="chat",
client=client,
user_agent=user_agent,
referer=referer,
host=host,
**common.__dict__,
)
return {"status": "ok", "response": meta_log.get("chat", [])}
@ -554,10 +537,7 @@ def chat_history(
@requires(["authenticated"])
async def chat_options(
request: Request,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
common: CommonQueryParams,
) -> Response:
cmd_options = {}
for cmd in ConversationCommand:
@ -567,10 +547,7 @@ async def chat_options(
request=request,
telemetry_type="api",
api="chat_options",
client=client,
user_agent=user_agent,
referer=referer,
host=host,
**common.__dict__,
)
return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200)
@ -579,14 +556,11 @@ async def chat_options(
@requires(["authenticated"])
async def chat(
request: Request,
common: CommonQueryParams,
q: str,
n: Optional[int] = 5,
d: Optional[float] = 0.18,
client: Optional[str] = None,
stream: Optional[bool] = False,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
) -> Response:
@ -600,7 +574,7 @@ async def chat(
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, meta_log, q, (n or 5), (d or math.inf), conversation_command
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
)
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
@ -634,11 +608,8 @@ async def chat(
request=request,
telemetry_type="api",
api="chat",
client=client,
user_agent=user_agent,
referer=referer,
host=host,
metadata=chat_metadata,
**common.__dict__,
)
if llm_response is None:
@ -665,6 +636,7 @@ async def chat(
async def extract_references_and_questions(
request: Request,
common: CommonQueryParams,
meta_log: dict,
q: str,
n: int,
@ -731,6 +703,7 @@ async def extract_references_and_questions(
r=True,
max_distance=d,
dedupe=False,
common=common,
)
)
# Dedupe the results again, as duplicates may be returned across queries.

View file

@ -6,10 +6,10 @@ from datetime import datetime
from functools import partial
import logging
from time import time
from typing import Iterator, List, Optional, Union, Tuple, Dict
from typing import Annotated, Iterator, List, Optional, Union, Tuple, Dict
# External Packages
from fastapi import HTTPException, Request
from fastapi import HTTPException, Header, Request, Depends
# Internal Packages
from khoj.utils import state
@ -221,3 +221,20 @@ class ApiUserRateLimiter:
# Add the current request to the cache
user_requests.append(time())
class CommonQueryParamsClass:
def __init__(
self,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
self.client = client
self.user_agent = user_agent
self.referer = referer
self.host = host
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]

View file

@ -9,9 +9,6 @@ import os
from fastapi import FastAPI
app = FastAPI()
# Internal Packages
from khoj.configure import configure_routes, configure_search_types, configure_middleware
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
@ -320,6 +317,7 @@ def client(
state.anonymous_mode = False
app = FastAPI()
configure_routes(app)
configure_middleware(app)
app.mount("/static", StaticFiles(directory=web_directory), name="static")

View file

@ -227,7 +227,7 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default
# Assert
assert response.status_code == 200
assert response_message == prompts.no_notes_found.format()
assert response_message == prompts.no_entries_found.format()
# ----------------------------------------------------------------------------------------------------