Mirror of https://github.com/khoj-ai/khoj.git
Improve comments, exceptions, typing and init of OpenAI model code
parent c0ae8eee99
commit 6119005838
1 changed file with 17 additions and 7 deletions
@@ -4,25 +4,35 @@ import torch
 from tqdm import trange

 # Internal Packages
-from src.utils.state import processor_config
+from src.utils.state import processor_config, config_file


 class OpenAI:
     def __init__(self, model_name, device=None):
         self.model_name = model_name
+        if not processor_config or not processor_config.conversation or not processor_config.conversation.openai_api_key:
+            raise Exception(f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {config_file}")
         openai.api_key = processor_config.conversation.openai_api_key
-        self.embedding_dimensions = 1536 # Default to embedding dimensions of text-embedding-ada-002 model
+        self.embedding_dimensions = None

-    def encode(self, entries, device=None, **kwargs):
+    def encode(self, entries: list[str], device=None, **kwargs):
         embedding_tensors = []

         for index in trange(0, len(entries)):
+            # OpenAI models create better embeddings for entries without newlines
+            processed_entry = entries[index].replace('\n', ' ')

             try:
-                response = openai.Embedding.create(input=entries[index], model=self.model_name)
+                response = openai.Embedding.create(input=processed_entry, model=self.model_name)
                 embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)]
-                self.embedding_dimensions = len(response.data[0].embedding) # Set embedding dimensions to this model's
+                # Use current models embedding dimension, once available
+                # Else default to embedding dimensions of the text-embedding-ada-002 model
+                self.embedding_dimensions = len(response.data[0].embedding) if not self.embedding_dimensions else 1536
             except Exception as e:
                 print(f"Failed to encode entry {index} of length: {len(entries[index])}\n\n{entries[index][:1000]}...\n\n{e}")
+                # Use zero embedding vector for entries with failed embeddings
+                # This ensures entry embeddings match the order of the source entries
+                # And they have minimal similarity to other entries (as zero vectors are always orthogonal to other vector)
                 embedding_tensors += [torch.zeros(self.embedding_dimensions, device=device)]
-        return torch.stack(embedding_tensors)
+
+        return torch.stack(embedding_tensors)
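The new comments in the except branch rest on a simple property: a zero vector has zero dot product with every other vector, so a failed entry never ranks as similar to anything at search time, while the embedding list still stays aligned with the source entries. A quick standalone check of that property (not part of this commit, plain PyTorch):

import torch
import torch.nn.functional as F

embedding = torch.tensor([0.3, -0.2, 0.9])  # stand-in for a successfully encoded entry
fallback = torch.zeros(3)                   # zero vector substituted for a failed entry

# Dot product with the zero vector is always 0, so cosine similarity is 0 as well
print(torch.dot(fallback, embedding))                   # tensor(0.)
print(F.cosine_similarity(fallback, embedding, dim=0))  # tensor(0.)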
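For context, a minimal usage sketch of the class this commit touches. The import path and model name below are assumptions (the diff does not name the module that defines OpenAI), and the sketch presumes a valid openai-api-key is set under processor-config > conversation in the khoj config file:

# Hypothetical import path; adjust to wherever the OpenAI encoder class is defined
from src.utils.models import OpenAI

# Model name is an assumption; any OpenAI embedding model id should work here
encoder = OpenAI(model_name="text-embedding-ada-002")

entries = [
    "First note about project planning.",
    "Second note,\nsplit across lines (newlines are stripped before embedding).",
]

# encode() returns one row per entry; entries that fail to embed become zero vectors
embeddings = encoder.encode(entries)
print(embeddings.shape)  # torch.Size([2, embedding_dimensions])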