Agent support for LLMs with no function calling ()

* add LMStudio (generic) agent support
"works" with non-tool-calling LLMs, highly dependent on system specs

* add comments

* enable few-shot prompting per function for OSS models

* Add Agent support for Ollama models

* azure, groq, koboldcpp agent support complete + WIP togetherai

* WIP gemini agent support

* WIP gemini agent support blocked; will not fix for now

* azure fix

* merge fix

* add localai agent support

* azure untooled agent support

* merge fix

* refactor implementation of several agent providers

* update bad merge comment

---------

Co-authored-by: timothycarambat <rambat1010@gmail.com>
Sean Hatfield 2024-05-08 15:17:54 -07:00 committed by GitHub
parent b2b41db110
commit 8422f92542
14 changed files with 645 additions and 8 deletions
.vscode
frontend/src/pages/WorkspaceSettings/AgentConfig/AgentLLMSelection
server/utils

View file

@@ -28,6 +28,7 @@
"openrouter",
"Qdrant",
"Serper",
"togetherai",
"vectordbs",
"Weaviate",
"Zilliz"

View file

@@ -5,8 +5,25 @@ import { AVAILABLE_LLM_PROVIDERS } from "@/pages/GeneralSettings/LLMPreference";
import { CaretUpDown, Gauge, MagnifyingGlass, X } from "@phosphor-icons/react";
import AgentModelSelection from "../AgentModelSelection";
const ENABLED_PROVIDERS = [
"openai",
"anthropic",
"lmstudio",
"ollama",
"localai",
"groq",
"azure",
"koboldcpp",
"togetherai",
];
const WARN_PERFORMANCE = [
"lmstudio",
"groq",
"azure",
"koboldcpp",
"ollama",
"localai",
];
const LLM_DEFAULT = {
name: "Please make a selection",

View file

@@ -480,7 +480,7 @@ Read the following conversation.
CHAT HISTORY
${history.map((c) => `@${c.from}: ${c.content}`).join("\n")}
Then select the next role from that is going to speak next.
Only return the role.
`,
},
@@ -522,7 +522,7 @@ Only return the role.
? [
{
role: "user",
content: `You are in a whatsapp group. Read the following conversation and then reply.
Do not add introduction or conclusion to your reply because this will be a continuous conversation. Don't introduce yourself.
CHAT HISTORY
@@ -743,6 +743,16 @@ ${this.getHistory({ to: route.to })
return new Providers.LMStudioProvider({});
case "ollama":
return new Providers.OllamaProvider({ model: config.model });
case "groq":
return new Providers.GroqProvider({ model: config.model });
case "togetherai":
return new Providers.TogetherAIProvider({ model: config.model });
case "azure":
return new Providers.AzureOpenAiProvider({ model: config.model });
case "koboldcpp":
return new Providers.KoboldCPPProvider({});
case "localai":
return new Providers.LocalAIProvider({ model: config.model });
default:
throw new Error(
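
For orientation, a minimal sketch of the config object this switch consumes, limited to the two fields visible in these cases (the provider key and config.model); the example model name is the Groq default used further down in this PR:

// Minimal sketch, assuming only the fields shown above exist on config.
const config = { provider: "groq", model: "llama3-70b-8192" };
// The switch above would resolve this to:
// new Providers.GroqProvider({ model: config.model })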

View file

@@ -58,6 +58,9 @@ class Provider {
}
}
// For some providers we may want to override the system prompt to be more verbose.
// Currently we only do this for lmstudio, but we probably will want to expand this even more
// to any Untooled LLM.
static systemPrompt(provider = null) {
switch (provider) {
case "lmstudio":

View file

@@ -0,0 +1,105 @@
const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
/**
* The provider for the Azure OpenAI API.
*/
class AzureOpenAiProvider extends InheritMultiple([Provider, UnTooled]) {
model;
constructor(_config = {}) {
super();
const client = new OpenAIClient(
process.env.AZURE_OPENAI_ENDPOINT,
new AzureKeyCredential(process.env.AZURE_OPENAI_KEY)
);
this._client = client;
this.model = process.env.OPEN_MODEL_PREF ?? "gpt-3.5-turbo";
this.verbose = true;
}
get client() {
return this._client;
}
async #handleFunctionCallChat({ messages = [] }) {
return await this.client
.getChatCompletions(this.model, messages, {
temperature: 0,
})
.then((result) => {
if (!result.hasOwnProperty("choices"))
throw new Error("Azure OpenAI chat: No results!");
if (result.choices.length === 0)
throw new Error("Azure OpenAI chat: No results length!");
return result.choices[0].message.content;
})
.catch((_) => {
return null;
});
}
/**
* Create a completion based on the received messages.
*
* @param messages A list of messages to send to the API.
* @param functions
* @returns The completion.
*/
async complete(messages, functions = null) {
try {
let completion;
if (functions.length > 0) {
const { toolCall, text } = await this.functionCall(
messages,
functions,
this.#handleFunctionCallChat.bind(this)
);
if (toolCall !== null) {
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
return {
result: null,
functionCall: {
name: toolCall.name,
arguments: toolCall.arguments,
},
cost: 0,
};
}
completion = { content: text };
}
if (!completion?.content) {
this.providerLog(
"Will assume chat completion without tool call inputs."
);
const response = await this.client.getChatCompletions(
this.model,
this.cleanMsgs(messages),
{
temperature: 0.7,
}
);
completion = response.choices[0].message;
}
return { result: completion.content, cost: 0 };
} catch (error) {
throw error;
}
}
/**
* Get the cost of the completion.
* Stubbed since Azure OpenAI has no public cost basis.
*
* @param _usage The completion to get the cost for.
* @returns The cost of the completion.
*/
getCost(_usage) {
return 0;
}
}
module.exports = AzureOpenAiProvider;
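
A minimal usage sketch for the provider above, assuming AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY are set; the get_weather function definition is made up for illustration and is not part of this PR:

// Illustrative only (inside an async context).
const provider = new AzureOpenAiProvider({});
const { result, functionCall } = await provider.complete(
  [{ role: "user", content: "What's the weather in Berlin?" }],
  [
    {
      name: "get_weather", // hypothetical function, for illustration
      description: "Look up the current weather for a city.",
      parameters: { type: "object", properties: { city: { type: "string" } } },
    },
  ]
);
// Either `result` is plain text, or `functionCall` is { name, arguments }
// for the agent loop to execute before continuing the conversation.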

View file

@@ -0,0 +1,110 @@
const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const { RetryError } = require("../error.js");
/**
* The provider for the Groq provider.
*/
class GroqProvider extends Provider {
model;
constructor(config = {}) {
const { model = "llama3-8b-8192" } = config;
const client = new OpenAI({
baseURL: "https://api.groq.com/openai/v1",
apiKey: process.env.GROQ_API_KEY,
maxRetries: 3,
});
super(client);
this.model = model;
this.verbose = true;
}
/**
* Create a completion based on the received messages.
*
* @param messages A list of messages to send to the API.
* @param functions
* @returns The completion.
*/
async complete(messages, functions = null) {
try {
const response = await this.client.chat.completions.create({
model: this.model,
// stream: true,
messages,
...(Array.isArray(functions) && functions?.length > 0
? { functions }
: {}),
});
// Right now, we only support one completion,
// so we just take the first one in the list
const completion = response.choices[0].message;
const cost = this.getCost(response.usage);
// treat function calls
if (completion.function_call) {
let functionArgs = {};
try {
functionArgs = JSON.parse(completion.function_call.arguments);
} catch (error) {
// call the complete function again in case it gets a json error
return this.complete(
[
...messages,
{
role: "function",
name: completion.function_call.name,
function_call: completion.function_call,
content: error?.message,
},
],
functions
);
}
// console.log(completion, { functionArgs })
return {
result: null,
functionCall: {
name: completion.function_call.name,
arguments: functionArgs,
},
cost,
};
}
return {
result: completion.content,
cost,
};
} catch (error) {
// If invalid Auth error we need to abort because no amount of waiting
// will make auth better.
if (error instanceof OpenAI.AuthenticationError) throw error;
if (
error instanceof OpenAI.RateLimitError ||
error instanceof OpenAI.InternalServerError ||
error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
) {
throw new RetryError(error.message);
}
throw error;
}
}
/**
* Get the cost of the completion.
*
* @param _usage The completion to get the cost for.
* @returns The cost of the completion.
* Stubbed since Groq has no cost basis.
*/
getCost(_usage) {
return 0;
}
}
module.exports = GroqProvider;

View file

@@ -110,7 +110,7 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`;
const response = await chatCb({
messages: [
{
content: `You are a program which picks the most optimal function and parameters to call.
DO NOT HAVE TO PICK A FUNCTION IF IT WILL NOT HELP ANSWER OR FULFILL THE USER'S QUERY.
When a function is selection, respond in JSON with no additional text.
When there is no relevant function to call - return with a regular chat text response.
@@ -130,7 +130,6 @@ ${JSON.stringify(def.parameters.properties, null, 4)}\n`;
...history,
],
});
const call = safeJsonParse(response, null);
if (call === null) return { toolCall: null, text: response }; // failed to parse, so must be text.
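
For context, the shape of reply this prompt is steering the model toward: a bare JSON object that safeJsonParse turns into the toolCall the providers read (toolCall.name / toolCall.arguments). The function name and values below are illustrative only:

// Expected model reply, verbatim and with no surrounding prose (illustrative values):
// {"name":"get_weather","arguments":{"city":"Berlin"}}
// Any reply that fails to parse as JSON is returned as a normal text answer instead.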

View file

@@ -2,10 +2,20 @@ const OpenAIProvider = require("./openai.js");
const AnthropicProvider = require("./anthropic.js");
const LMStudioProvider = require("./lmstudio.js");
const OllamaProvider = require("./ollama.js");
const GroqProvider = require("./groq.js");
const TogetherAIProvider = require("./togetherai.js");
const AzureOpenAiProvider = require("./azure.js");
const KoboldCPPProvider = require("./koboldcpp.js");
const LocalAIProvider = require("./localai.js");
module.exports = {
OpenAIProvider,
AnthropicProvider,
LMStudioProvider,
OllamaProvider,
GroqProvider,
TogetherAIProvider,
AzureOpenAiProvider,
KoboldCPPProvider,
LocalAIProvider,
};

View file

@@ -0,0 +1,113 @@
const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
/**
* The provider for the KoboldCPP provider.
*/
class KoboldCPPProvider extends InheritMultiple([Provider, UnTooled]) {
model;
constructor(_config = {}) {
super();
const model = process.env.KOBOLD_CPP_MODEL_PREF ?? null;
const client = new OpenAI({
baseURL: process.env.KOBOLD_CPP_BASE_PATH?.replace(/\/+$/, ""),
apiKey: null,
maxRetries: 3,
});
this._client = client;
this.model = model;
this.verbose = true;
}
get client() {
return this._client;
}
async #handleFunctionCallChat({ messages = [] }) {
return await this.client.chat.completions
.create({
model: this.model,
temperature: 0,
messages,
})
.then((result) => {
if (!result.hasOwnProperty("choices"))
throw new Error("KoboldCPP chat: No results!");
if (result.choices.length === 0)
throw new Error("KoboldCPP chat: No results length!");
return result.choices[0].message.content;
})
.catch((_) => {
return null;
});
}
/**
* Create a completion based on the received messages.
*
* @param messages A list of messages to send to the API.
* @param functions
* @returns The completion.
*/
async complete(messages, functions = null) {
try {
let completion;
if (functions.length > 0) {
const { toolCall, text } = await this.functionCall(
messages,
functions,
this.#handleFunctionCallChat.bind(this)
);
if (toolCall !== null) {
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
return {
result: null,
functionCall: {
name: toolCall.name,
arguments: toolCall.arguments,
},
cost: 0,
};
}
completion = { content: text };
}
if (!completion?.content) {
this.providerLog(
"Will assume chat completion without tool call inputs."
);
const response = await this.client.chat.completions.create({
model: this.model,
messages: this.cleanMsgs(messages),
});
completion = response.choices[0].message;
}
return {
result: completion.content,
cost: 0,
};
} catch (error) {
throw error;
}
}
/**
* Get the cost of the completion.
*
* @param _usage The completion to get the cost for.
* @returns The cost of the completion.
* Stubbed since KoboldCPP has no cost basis.
*/
getCost(_usage) {
return 0;
}
}
module.exports = KoboldCPPProvider;

View file

@@ -16,8 +16,8 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""), // here is the URL to your LMStudio instance
apiKey: null,
maxRetries: 3,
model,
});
this._client = client;
this.model = model;
this.verbose = true;

View file

@@ -0,0 +1,114 @@
const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
/**
* The provider for the LocalAI provider.
*/
class LocalAiProvider extends InheritMultiple([Provider, UnTooled]) {
model;
constructor(config = {}) {
const { model = null } = config;
super();
const client = new OpenAI({
baseURL: process.env.LOCAL_AI_BASE_PATH,
apiKey: process.env.LOCAL_AI_API_KEY ?? null,
maxRetries: 3,
});
this._client = client;
this.model = model;
this.verbose = true;
}
get client() {
return this._client;
}
async #handleFunctionCallChat({ messages = [] }) {
return await this.client.chat.completions
.create({
model: this.model,
temperature: 0,
messages,
})
.then((result) => {
if (!result.hasOwnProperty("choices"))
throw new Error("LocalAI chat: No results!");
if (result.choices.length === 0)
throw new Error("LocalAI chat: No results length!");
return result.choices[0].message.content;
})
.catch((_) => {
return null;
});
}
/**
* Create a completion based on the received messages.
*
* @param messages A list of messages to send to the API.
* @param functions
* @returns The completion.
*/
async complete(messages, functions = null) {
try {
let completion;
if (functions.length > 0) {
const { toolCall, text } = await this.functionCall(
messages,
functions,
this.#handleFunctionCallChat.bind(this)
);
if (toolCall !== null) {
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
return {
result: null,
functionCall: {
name: toolCall.name,
arguments: toolCall.arguments,
},
cost: 0,
};
}
completion = { content: text };
}
if (!completion?.content) {
this.providerLog(
"Will assume chat completion without tool call inputs."
);
const response = await this.client.chat.completions.create({
model: this.model,
messages: this.cleanMsgs(messages),
});
completion = response.choices[0].message;
}
return { result: completion.content, cost: 0 };
} catch (error) {
throw error;
}
}
/**
* Get the cost of the completion.
*
* @param _usage The completion to get the cost for.
* @returns The cost of the completion.
* Stubbed since LocalAI has no cost basis.
*/
getCost(_usage) {
return 0;
}
}
module.exports = LocalAiProvider;

View file

@@ -0,0 +1,113 @@
const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
/**
* The provider for the TogetherAI provider.
*/
class TogetherAIProvider extends InheritMultiple([Provider, UnTooled]) {
model;
constructor(config = {}) {
const { model = "mistralai/Mistral-7B-Instruct-v0.1" } = config;
super();
const client = new OpenAI({
baseURL: "https://api.together.xyz/v1",
apiKey: process.env.TOGETHER_AI_API_KEY,
maxRetries: 3,
});
this._client = client;
this.model = model;
this.verbose = true;
}
get client() {
return this._client;
}
async #handleFunctionCallChat({ messages = [] }) {
return await this.client.chat.completions
.create({
model: this.model,
temperature: 0,
messages,
})
.then((result) => {
if (!result.hasOwnProperty("choices"))
throw new Error("LMStudio chat: No results!");
if (result.choices.length === 0)
throw new Error("LMStudio chat: No results length!");
return result.choices[0].message.content;
})
.catch((_) => {
return null;
});
}
/**
* Create a completion based on the received messages.
*
* @param messages A list of messages to send to the API.
* @param functions
* @returns The completion.
*/
async complete(messages, functions = null) {
try {
let completion;
if (functions.length > 0) {
const { toolCall, text } = await this.functionCall(
messages,
functions,
this.#handleFunctionCallChat.bind(this)
);
if (toolCall !== null) {
this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
return {
result: null,
functionCall: {
name: toolCall.name,
arguments: toolCall.arguments,
},
cost: 0,
};
}
completion = { content: text };
}
if (!completion?.content) {
this.providerLog(
"Will assume chat completion without tool call inputs."
);
const response = await this.client.chat.completions.create({
model: this.model,
messages: this.cleanMsgs(messages),
});
completion = response.choices[0].message;
}
return {
result: completion.content,
cost: 0,
};
} catch (error) {
throw error;
}
}
/**
* Get the cost of the completion.
*
* @param _usage The completion to get the cost for.
* @returns The cost of the completion.
* Stubbed since LMStudio has no cost basis.
*/
getCost(_usage) {
return 0;
}
}
module.exports = TogetherAIProvider;

View file

@@ -85,6 +85,36 @@ class AgentHandler {
if (!process.env.OLLAMA_BASE_PATH)
throw new Error("Ollama base path must be provided to use agents.");
break;
case "groq":
if (!process.env.GROQ_API_KEY)
throw new Error("Groq API key must be provided to use agents.");
break;
case "togetherai":
if (!process.env.TOGETHER_AI_API_KEY)
throw new Error("TogetherAI API key must be provided to use agents.");
break;
case "azure":
if (!process.env.AZURE_OPENAI_ENDPOINT || !process.env.AZURE_OPENAI_KEY)
throw new Error(
"Azure OpenAI API endpoint and key must be provided to use agents."
);
break;
case "koboldcpp":
if (!process.env.KOBOLD_CPP_BASE_PATH)
throw new Error(
"KoboldCPP must have a valid base path to use for the api."
);
break;
case "localai":
if (!process.env.LOCAL_AI_BASE_PATH)
throw new Error(
"LocalAI must have a valid base path to use for the api."
);
break;
case "gemini":
if (!process.env.GEMINI_API_KEY)
throw new Error("Gemini API key must be provided to use agents.");
break;
default:
throw new Error("No provider found to power agent cluster.");
}
@@ -100,6 +130,18 @@
return "server-default";
case "ollama":
return "llama3:latest";
case "groq":
return "llama3-70b-8192";
case "togetherai":
return "mistralai/Mixtral-8x7B-Instruct-v0.1";
case "azure":
return "gpt-3.5-turbo";
case "koboldcpp":
return null;
case "gemini":
return "gemini-pro";
case "localai":
return null;
default:
return "unknown";
}
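
For reference, the environment variables validated above, as a sample .env fragment; all values are placeholders and the base-path ports are common defaults, not requirements:

GROQ_API_KEY=gsk-placeholder
TOGETHER_AI_API_KEY=placeholder
AZURE_OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com
AZURE_OPENAI_KEY=placeholder
KOBOLD_CPP_BASE_PATH=http://127.0.0.1:5001/v1
LOCAL_AI_BASE_PATH=http://127.0.0.1:8080/v1
GEMINI_API_KEY=placeholder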

View file

@@ -178,7 +178,7 @@ async function getKoboldCPPModels(basePath = null) {
try {
const { OpenAI: OpenAIApi } = require("openai");
const openai = new OpenAIApi({
baseURL: basePath || process.env.KOBOLD_CPP_BASE_PATH,
apiKey: null,
});
const models = await openai.models