[Fork] Batch embed by jwaltz ()

* refactor: convert chunk embedding to one API call

* chore: lint

* fix chroma for batch and single vectorization of text

* Fix LanceDB multi and single vectorization

* Fix pinecone for single and multiple embeddings

---------

Co-authored-by: Jonathan Waltz <volcanicislander@gmail.com>
This commit is contained in:
Timothy Carambat 2023-07-20 12:05:23 -07:00 committed by GitHub
parent 5a7d8add6f
commit c1deca4928
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 84 additions and 61 deletions
server/utils
helpers
vectorDbProviders
chroma
lance
pinecone

View file

@ -26,14 +26,24 @@ function curateSources(sources = []) {
const knownDocs = []; const knownDocs = [];
const documents = []; const documents = [];
// Sometimes the source may or may not have a metadata property
// in the response so we search for it explicitly or just spread the entire
// source and check to see if at least title exists.
for (const source of sources) { for (const source of sources) {
const { metadata = {} } = source; if (source.hasOwnProperty("metadata")) {
if ( const { metadata = {} } = source;
Object.keys(metadata).length > 0 && if (
!knownDocs.includes(metadata.title) Object.keys(metadata).length > 0 &&
) { !knownDocs.includes(metadata.title)
documents.push({ ...metadata }); ) {
knownDocs.push(metadata.title); documents.push({ ...metadata });
knownDocs.push(metadata.title);
}
} else {
if (Object.keys(source).length > 0 && !knownDocs.includes(source.title)) {
documents.push({ ...source });
knownDocs.push(source.title);
}
} }
} }

View file

@ -80,15 +80,20 @@ const Chroma = {
temperature, temperature,
}); });
}, },
embedChunk: async function (openai, textChunk) { embedTextInput: async function (openai, textInput) {
const result = await this.embedChunks(openai, textInput);
return result?.[0] || [];
},
embedChunks: async function (openai, chunks = []) {
const { const {
data: { data }, data: { data },
} = await openai.createEmbedding({ } = await openai.createEmbedding({
model: "text-embedding-ada-002", model: "text-embedding-ada-002",
input: textChunk, input: chunks,
}); });
return data.length > 0 && data[0].hasOwnProperty("embedding") return data.length > 0 &&
? data[0].embedding data.every((embd) => embd.hasOwnProperty("embedding"))
? data.map((embd) => embd.embedding)
: null; : null;
}, },
similarityResponse: async function (client, namespace, queryVector) { similarityResponse: async function (client, namespace, queryVector) {
@ -205,7 +210,7 @@ const Chroma = {
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const openai = this.openai(); const openai = this.openai();
const vectorValues = await this.embedChunks(openai, textChunks);
const submission = { const submission = {
ids: [], ids: [],
embeddings: [], embeddings: [],
@ -213,31 +218,29 @@ const Chroma = {
documents: [], documents: [],
}; };
for (const textChunk of textChunks) { if (!!vectorValues && vectorValues.length > 0) {
const vectorValues = await this.embedChunk(openai, textChunk); for (const [i, vector] of vectorValues.entries()) {
if (!!vectorValues) {
const vectorRecord = { const vectorRecord = {
id: uuidv4(), id: uuidv4(),
values: vectorValues, values: vector,
// [DO NOT REMOVE] // [DO NOT REMOVE]
// LangChain will be unable to find your text if you embed manually and dont include the `text` key. // LangChain will be unable to find your text if you embed manually and dont include the `text` key.
// https://github.com/hwchase17/langchainjs/blob/2def486af734c0ca87285a48f1a04c057ab74bdf/langchain/src/vectorstores/pinecone.ts#L64 // https://github.com/hwchase17/langchainjs/blob/2def486af734c0ca87285a48f1a04c057ab74bdf/langchain/src/vectorstores/pinecone.ts#L64
metadata: { ...metadata, text: textChunk }, metadata: { ...metadata, text: textChunks[i] },
}; };
submission.ids.push(vectorRecord.id); submission.ids.push(vectorRecord.id);
submission.embeddings.push(vectorRecord.values); submission.embeddings.push(vectorRecord.values);
submission.metadatas.push(metadata); submission.metadatas.push(metadata);
submission.documents.push(textChunk); submission.documents.push(textChunks[i]);
vectors.push(vectorRecord); vectors.push(vectorRecord);
documentVectors.push({ docId, vectorId: vectorRecord.id }); documentVectors.push({ docId, vectorId: vectorRecord.id });
} else {
console.error(
"Could not use OpenAI to embed document chunk! This document will not be recorded."
);
} }
} else {
console.error(
"Could not use OpenAI to embed document chunks! This document will not be recorded."
);
} }
const { client } = await this.connect(); const { client } = await this.connect();
@ -340,7 +343,7 @@ const Chroma = {
}; };
} }
const queryVector = await this.embedChunk(this.openai(), input); const queryVector = await this.embedTextInput(this.openai(), input);
const { contextTexts, sourceDocuments } = await this.similarityResponse( const { contextTexts, sourceDocuments } = await this.similarityResponse(
client, client,
namespace, namespace,

View file

@ -51,6 +51,22 @@ const LanceDb = {
process.env.OPEN_AI_KEY process.env.OPEN_AI_KEY
); );
}, },
embedTextInput: async function (openai, textInput) {
const result = await this.embedChunks(openai, textInput);
return result?.[0] || [];
},
embedChunks: async function (openai, chunks = []) {
const {
data: { data },
} = await openai.createEmbedding({
model: "text-embedding-ada-002",
input: chunks,
});
return data.length > 0 &&
data.every((embd) => embd.hasOwnProperty("embedding"))
? data.map((embd) => embd.embedding)
: null;
},
embedder: function () { embedder: function () {
return new OpenAIEmbeddings({ openAIApiKey: process.env.OPEN_AI_KEY }); return new OpenAIEmbeddings({ openAIApiKey: process.env.OPEN_AI_KEY });
}, },
@ -59,17 +75,6 @@ const LanceDb = {
const openai = new OpenAIApi(config); const openai = new OpenAIApi(config);
return openai; return openai;
}, },
embedChunk: async function (openai, textChunk) {
const {
data: { data },
} = await openai.createEmbedding({
model: "text-embedding-ada-002",
input: textChunk,
});
return data.length > 0 && data[0].hasOwnProperty("embedding")
? data[0].embedding
: null;
},
getChatCompletion: async function ( getChatCompletion: async function (
openai, openai,
messages = [], messages = [],
@ -194,18 +199,17 @@ const LanceDb = {
const vectors = []; const vectors = [];
const submissions = []; const submissions = [];
const openai = this.openai(); const openai = this.openai();
const vectorValues = await this.embedChunks(openai, textChunks);
for (const textChunk of textChunks) { if (!!vectorValues && vectorValues.length > 0) {
const vectorValues = await this.embedChunk(openai, textChunk); for (const [i, vector] of vectorValues.entries()) {
if (!!vectorValues) {
const vectorRecord = { const vectorRecord = {
id: uuidv4(), id: uuidv4(),
values: vectorValues, values: vector,
// [DO NOT REMOVE] // [DO NOT REMOVE]
// LangChain will be unable to find your text if you embed manually and dont include the `text` key. // LangChain will be unable to find your text if you embed manually and dont include the `text` key.
// https://github.com/hwchase17/langchainjs/blob/2def486af734c0ca87285a48f1a04c057ab74bdf/langchain/src/vectorstores/pinecone.ts#L64 // https://github.com/hwchase17/langchainjs/blob/2def486af734c0ca87285a48f1a04c057ab74bdf/langchain/src/vectorstores/pinecone.ts#L64
metadata: { ...metadata, text: textChunk }, metadata: { ...metadata, text: textChunks[i] },
}; };
vectors.push(vectorRecord); vectors.push(vectorRecord);
@ -215,11 +219,11 @@ const LanceDb = {
...vectorRecord.metadata, ...vectorRecord.metadata,
}); });
documentVectors.push({ docId, vectorId: vectorRecord.id }); documentVectors.push({ docId, vectorId: vectorRecord.id });
} else {
console.error(
"Could not use OpenAI to embed document chunk! This document will not be recorded."
);
} }
} else {
console.error(
"Could not use OpenAI to embed document chunks! This document will not be recorded."
);
} }
if (vectors.length > 0) { if (vectors.length > 0) {
@ -253,7 +257,7 @@ const LanceDb = {
} }
// LanceDB does not have langchainJS support so we roll our own here. // LanceDB does not have langchainJS support so we roll our own here.
const queryVector = await this.embedChunk(this.openai(), input); const queryVector = await this.embedTextInput(this.openai(), input);
const { contextTexts, sourceDocuments } = await this.similarityResponse( const { contextTexts, sourceDocuments } = await this.similarityResponse(
client, client,
namespace, namespace,
@ -302,7 +306,7 @@ const LanceDb = {
}; };
} }
const queryVector = await this.embedChunk(this.openai(), input); const queryVector = await this.embedTextInput(this.openai(), input);
const { contextTexts, sourceDocuments } = await this.similarityResponse( const { contextTexts, sourceDocuments } = await this.similarityResponse(
client, client,
namespace, namespace,

View file

@ -54,15 +54,20 @@ const Pinecone = {
if (!data.hasOwnProperty("choices")) return null; if (!data.hasOwnProperty("choices")) return null;
return data.choices[0].message.content; return data.choices[0].message.content;
}, },
embedChunk: async function (openai, textChunk) { embedTextInput: async function (openai, textInput) {
const result = await this.embedChunks(openai, textInput);
return result?.[0] || [];
},
embedChunks: async function (openai, chunks = []) {
const { const {
data: { data }, data: { data },
} = await openai.createEmbedding({ } = await openai.createEmbedding({
model: "text-embedding-ada-002", model: "text-embedding-ada-002",
input: textChunk, input: chunks,
}); });
return data.length > 0 && data[0].hasOwnProperty("embedding") return data.length > 0 &&
? data[0].embedding data.every((embd) => embd.hasOwnProperty("embedding"))
? data.map((embd) => embd.embedding)
: null; : null;
}, },
llm: function ({ temperature = 0.7 }) { llm: function ({ temperature = 0.7 }) {
@ -175,25 +180,26 @@ const Pinecone = {
const documentVectors = []; const documentVectors = [];
const vectors = []; const vectors = [];
const openai = this.openai(); const openai = this.openai();
for (const textChunk of textChunks) { const vectorValues = await this.embedChunks(openai, textChunks);
const vectorValues = await this.embedChunk(openai, textChunk);
if (!!vectorValues) { if (!!vectorValues && vectorValues.length > 0) {
for (const [i, vector] of vectorValues.entries()) {
const vectorRecord = { const vectorRecord = {
id: uuidv4(), id: uuidv4(),
values: vectorValues, values: vector,
// [DO NOT REMOVE] // [DO NOT REMOVE]
// LangChain will be unable to find your text if you embed manually and dont include the `text` key. // LangChain will be unable to find your text if you embed manually and dont include the `text` key.
// https://github.com/hwchase17/langchainjs/blob/2def486af734c0ca87285a48f1a04c057ab74bdf/langchain/src/vectorstores/pinecone.ts#L64 // https://github.com/hwchase17/langchainjs/blob/2def486af734c0ca87285a48f1a04c057ab74bdf/langchain/src/vectorstores/pinecone.ts#L64
metadata: { ...metadata, text: textChunk }, metadata: { ...metadata, text: textChunks[i] },
}; };
vectors.push(vectorRecord); vectors.push(vectorRecord);
documentVectors.push({ docId, vectorId: vectorRecord.id }); documentVectors.push({ docId, vectorId: vectorRecord.id });
} else {
console.error(
"Could not use OpenAI to embed document chunk! This document will not be recorded."
);
} }
} else {
console.error(
"Could not use OpenAI to embed document chunks! This document will not be recorded."
);
} }
if (vectors.length > 0) { if (vectors.length > 0) {
@ -311,7 +317,7 @@ const Pinecone = {
"Invalid namespace - has it been collected and seeded yet?" "Invalid namespace - has it been collected and seeded yet?"
); );
const queryVector = await this.embedChunk(this.openai(), input); const queryVector = await this.embedTextInput(this.openai(), input);
const { contextTexts, sourceDocuments } = await this.similarityResponse( const { contextTexts, sourceDocuments } = await this.similarityResponse(
pineconeIndex, pineconeIndex,
namespace, namespace,