Merge pull request #553 from khoj-ai/features/validation-errors

Update types of base config models for pydantic 2.0
sabaimran 2023-11-18 00:42:56 -08:00 committed by GitHub
commit ebdb423d3e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 48 additions and 34 deletions

View file

@@ -151,17 +151,20 @@ def truncate_messages(
         )
 
     system_message = messages.pop()
+    assert type(system_message.content) == str
     system_message_tokens = len(encoder.encode(system_message.content))
 
-    tokens = sum([len(encoder.encode(message.content)) for message in messages])
+    tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
     while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
         messages.pop()
-        tokens = sum([len(encoder.encode(message.content)) for message in messages])
+        assert type(system_message.content) == str
+        tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
 
     # Truncate current message if still over max supported prompt size by model
     if (tokens + system_message_tokens) > max_prompt_size:
-        current_message = "\n".join(messages[0].content.split("\n")[:-1])
-        original_question = "\n".join(messages[0].content.split("\n")[-1:])
+        assert type(system_message.content) == str
+        current_message = "\n".join(messages[0].content.split("\n")[:-1]) if type(messages[0].content) == str else ""
+        original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
         original_question_tokens = len(encoder.encode(original_question))
         remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens
         truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
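Why the asserts: after this upgrade a chat message's content may be typed as something other than a plain string (for example a list of content parts), so token counting only encodes string content. A minimal standalone sketch of that guard, assuming tiktoken and langchain's ChatMessage as used by the surrounding code; the helper name count_text_tokens is illustrative:

import tiktoken
from langchain.schema import ChatMessage

encoder = tiktoken.encoding_for_model("gpt-4")

def count_text_tokens(messages: list[ChatMessage]) -> int:
    # Only encode plain-string content; skip anything else (e.g. multimodal parts),
    # mirroring the "if type(message.content) == str" filter in the diff above.
    return sum(len(encoder.encode(m.content)) for m in messages if type(m.content) == str)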

View file

@@ -296,7 +296,7 @@ async def get_all_filenames(
         client=client,
     )
 
-    return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source))
+    return await sync_to_async(list)(EntryAdapters.aget_all_filenames_by_source(user, content_source))  # type: ignore[call-arg]
 
 
 @api.post("/config/data/conversation/model", status_code=200)
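For context, sync_to_async(list)(queryset) is asgiref's usual way to evaluate a blocking Django queryset from an async FastAPI handler; the added type: ignore[call-arg] only quiets mypy about the wrapped callable's signature. A minimal sketch of the pattern, with an illustrative queryset argument:

from asgiref.sync import sync_to_async

async def materialize(queryset) -> list:
    # list(queryset) runs the blocking database query, so it executes in a
    # worker thread via sync_to_async rather than on the event loop.
    return await sync_to_async(list)(queryset)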

View file

@@ -63,7 +63,7 @@ async def update(
     request: Request,
     files: list[UploadFile],
     force: bool = False,
-    t: Optional[Union[state.SearchType, str]] = None,
+    t: Optional[Union[state.SearchType, str]] = state.SearchType.All,
     client: Optional[str] = None,
     user_agent: Optional[str] = Header(None),
     referer: Optional[str] = Header(None),
@@ -182,13 +182,16 @@ def configure_content(
     files: Optional[dict[str, dict[str, str]]],
     search_models: SearchModels,
     regenerate: bool = False,
-    t: Optional[state.SearchType] = None,
+    t: Optional[state.SearchType] = state.SearchType.All,
    full_corpus: bool = True,
     user: KhojUser = None,
 ) -> tuple[Optional[ContentIndex], bool]:
     content_index = ContentIndex()
     success = True
+    if t is not None and t in [type.value for type in state.SearchType]:
+        t = state.SearchType(t)
+
     if t is not None and not t.value in [type.value for type in state.SearchType]:
         logger.warning(f"🚨 Invalid search type: {t}")
         return None, False
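The two added lines coerce a search type arriving as a plain string (for example from the /update query parameter above) into the SearchType enum, and the hunks below then compare against the explicit state.SearchType.All.value sentinel instead of None. A standalone sketch of both steps, using an illustrative stand-in for state.SearchType:

from enum import Enum
from typing import Optional, Union

class SearchType(str, Enum):  # stand-in for khoj's state.SearchType; member list is illustrative
    All = "all"
    Org = "org"
    Markdown = "markdown"

def normalize(t: Optional[Union[SearchType, str]]) -> Optional[SearchType]:
    # Coerce a raw string into the enum when it matches a known value.
    if t is not None and t in [st.value for st in SearchType]:
        return SearchType(t)
    return t if isinstance(t, SearchType) else None

def should_index(search_type: str, content_type: SearchType) -> bool:
    # Index a content type when indexing everything or when it was requested explicitly.
    return search_type in (SearchType.All.value, content_type.value)

assert normalize("org") is SearchType.Org
assert should_index(SearchType.All.value, SearchType.Markdown)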
@@ -201,7 +204,7 @@ def configure_content(
 
     try:
         # Initialize Org Notes Search
-        if (search_type == None or search_type == state.SearchType.Org.value) and files["org"]:
+        if (search_type == state.SearchType.All.value or search_type == state.SearchType.Org.value) and files["org"]:
             logger.info("🦄 Setting up search for orgmode notes")
             # Extract Entries, Generate Notes Embeddings
             text_search.setup(
@@ -217,7 +220,9 @@ def configure_content(
 
     try:
         # Initialize Markdown Search
-        if (search_type == None or search_type == state.SearchType.Markdown.value) and files["markdown"]:
+        if (search_type == state.SearchType.All.value or search_type == state.SearchType.Markdown.value) and files[
+            "markdown"
+        ]:
             logger.info("💎 Setting up search for markdown notes")
             # Extract Entries, Generate Markdown Embeddings
             text_search.setup(
@@ -234,7 +239,7 @@ def configure_content(
 
     try:
         # Initialize PDF Search
-        if (search_type == None or search_type == state.SearchType.Pdf.value) and files["pdf"]:
+        if (search_type == state.SearchType.All.value or search_type == state.SearchType.Pdf.value) and files["pdf"]:
             logger.info("🖨️ Setting up search for pdf")
             # Extract Entries, Generate PDF Embeddings
             text_search.setup(
@@ -251,7 +256,9 @@ def configure_content(
 
     try:
         # Initialize Plaintext Search
-        if (search_type == None or search_type == state.SearchType.Plaintext.value) and files["plaintext"]:
+        if (search_type == state.SearchType.All.value or search_type == state.SearchType.Plaintext.value) and files[
+            "plaintext"
+        ]:
             logger.info("📄 Setting up search for plaintext")
             # Extract Entries, Generate Plaintext Embeddings
             text_search.setup(
@@ -269,7 +276,7 @@ def configure_content(
     try:
         # Initialize Image Search
         if (
-            (search_type == None or search_type == state.SearchType.Image.value)
+            (search_type == state.SearchType.All.value or search_type == state.SearchType.Image.value)
             and content_config
             and content_config.image
             and search_models.image_search
@@ -286,7 +293,9 @@ def configure_content(
 
     try:
         github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
-        if (search_type == None or search_type == state.SearchType.Github.value) and github_config is not None:
+        if (
+            search_type == state.SearchType.All.value or search_type == state.SearchType.Github.value
+        ) and github_config is not None:
             logger.info("🐙 Setting up search for github")
             # Extract Entries, Generate Github Embeddings
             text_search.setup(
@@ -305,7 +314,9 @@ def configure_content(
     try:
         # Initialize Notion Search
        notion_config = NotionConfig.objects.filter(user=user).first()
-        if (search_type == None or search_type in state.SearchType.Notion.value) and notion_config:
+        if (
+            search_type == state.SearchType.All.value or search_type in state.SearchType.Notion.value
+        ) and notion_config:
             logger.info("🔌 Setting up search for notion")
             text_search.setup(
                 NotionToEntries,

View file

@@ -229,7 +229,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
 
         # Add the image metadata to the results
         results += [
-            SearchResponse.parse_obj(
+            SearchResponse.model_validate(
                 {
                     "entry": f"{image_files_url}/{target_image_name}",
                     "score": f"{hit['score']:.9f}",
@@ -237,7 +237,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
                         "image_score": f"{hit['image_score']:.9f}",
                         "metadata_score": f"{hit['metadata_score']:.9f}",
                     },
-                    "corpus_id": hit["corpus_id"],
+                    "corpus_id": str(hit["corpus_id"]),
                 }
             )
         ]
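parse_obj is the pydantic 1.x API; pydantic 2.0 renames it to model_validate, and the str() cast is needed because pydantic 2 no longer coerces an int into a str field by default. A minimal sketch of the call with a trimmed, illustrative SearchResponse:

from typing import Optional
from pydantic import BaseModel

class SearchResponse(BaseModel):  # trimmed stand-in for khoj's SearchResponse model
    entry: str
    score: float
    additional: Optional[dict] = None
    corpus_id: str

hit = {"score": 0.123456789, "corpus_id": 42}
response = SearchResponse.model_validate(
    {
        "entry": "/static/images/example.jpg",  # illustrative image URL
        "score": f"{hit['score']:.9f}",  # a numeric string is still coerced to float
        "corpus_id": str(hit["corpus_id"]),  # int -> str coercion must now be explicit
    }
)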

View file

@@ -14,7 +14,7 @@ from khoj.utils.helpers import to_snake_case_from_dash
 class ConfigBase(BaseModel):
     class Config:
         alias_generator = to_snake_case_from_dash
-        allow_population_by_field_name = True
+        populate_by_name = True
 
     def __getitem__(self, item):
         return getattr(self, item)
@@ -29,8 +29,8 @@ class TextConfigBase(ConfigBase):
 
 
 class TextContentConfig(ConfigBase):
-    input_files: Optional[List[Path]]
-    input_filter: Optional[List[str]]
+    input_files: Optional[List[Path]] = None
+    input_filter: Optional[List[str]] = None
     index_heading_entries: Optional[bool] = False
@@ -50,31 +50,31 @@ class NotionContentConfig(ConfigBase):
 
 
 class ImageContentConfig(ConfigBase):
-    input_directories: Optional[List[Path]]
-    input_filter: Optional[List[str]]
+    input_directories: Optional[List[Path]] = None
+    input_filter: Optional[List[str]] = None
     embeddings_file: Path
     use_xmp_metadata: bool
     batch_size: int
 
 
 class ContentConfig(ConfigBase):
-    org: Optional[TextContentConfig]
-    image: Optional[ImageContentConfig]
-    markdown: Optional[TextContentConfig]
-    pdf: Optional[TextContentConfig]
-    plaintext: Optional[TextContentConfig]
-    github: Optional[GithubContentConfig]
-    notion: Optional[NotionContentConfig]
+    org: Optional[TextContentConfig] = None
+    image: Optional[ImageContentConfig] = None
+    markdown: Optional[TextContentConfig] = None
+    pdf: Optional[TextContentConfig] = None
+    plaintext: Optional[TextContentConfig] = None
+    github: Optional[GithubContentConfig] = None
+    notion: Optional[NotionContentConfig] = None
 
 
 class ImageSearchConfig(ConfigBase):
     encoder: str
-    encoder_type: Optional[str]
-    model_directory: Optional[Path]
+    encoder_type: Optional[str] = None
+    model_directory: Optional[Path] = None
 
 
 class SearchConfig(ConfigBase):
-    image: Optional[ImageSearchConfig]
+    image: Optional[ImageSearchConfig] = None
 
 
 class OpenAIProcessorConfig(ConfigBase):
@@ -95,7 +95,7 @@ class ConversationProcessorConfig(ConfigBase):
 
 
 class ProcessorConfig(ConfigBase):
-    conversation: Optional[ConversationProcessorConfig]
+    conversation: Optional[ConversationProcessorConfig] = None
 
 
 class AppConfig(ConfigBase):
@@ -113,8 +113,8 @@ class FullConfig(ConfigBase):
 
 
 class SearchResponse(ConfigBase):
     entry: str
     score: float
-    cross_score: Optional[float]
-    additional: Optional[dict]
+    cross_score: Optional[float] = None
+    additional: Optional[dict] = None
     corpus_id: str
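Both changes in this file are pydantic 2.0 migrations: the Config key allow_population_by_field_name was renamed to populate_by_name, and Optional fields no longer get an implicit None default, so each one needs an explicit "= None" to remain optional. A minimal sketch of the pattern, assuming pydantic 2 and an illustrative model:

from typing import Optional
from pydantic import BaseModel

class ConfigBase(BaseModel):
    class Config:
        populate_by_name = True  # pydantic 1.x spelling: allow_population_by_field_name

class ExampleConfig(ConfigBase):  # illustrative model, not from the khoj codebase
    required_path: str
    optional_note: Optional[str] = None  # without "= None", pydantic 2 treats this field as required

ExampleConfig(required_path="/tmp/example")  # validates; optional_note defaults to None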