Add support for Groq /models endpoint ()

* Add support for Groq /models endpoint

* linting
Timothy Carambat, 2024-07-24 08:35:52 -07:00 (committed by GitHub)
parent 23de85a3bd
commit 61e214aa8c
4 changed files with 131 additions and 56 deletions
Changed paths:
  frontend/src/components/LLMSelection/GroqAiOptions
  frontend/src/hooks
  server/utils/AiProviders/groq
  server/utils/helpers
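For context, Groq's API is OpenAI-compatible, so the /models endpoint this commit starts consuming lives at https://api.groq.com/openai/v1/models. A minimal sketch of querying it directly with plain fetch (Node 18+; assumes the standard OpenAI-style list response of { object: "list", data: [{ id, ... }] }):

// Sketch only: list model ids from Groq's OpenAI-compatible /models endpoint.
// Requires GROQ_API_KEY in the environment; run as an ES module for top-level await.
const response = await fetch("https://api.groq.com/openai/v1/models", {
  headers: { Authorization: `Bearer ${process.env.GROQ_API_KEY}` },
});
if (!response.ok) throw new Error(`Groq /models failed: ${response.status}`);
const { data } = await response.json();
console.log(data.map((model) => model.id));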

File: frontend/src/components/LLMSelection/GroqAiOptions

@@ -1,4 +1,10 @@
+import { useState, useEffect } from "react";
+import System from "@/models/system";
+
 export default function GroqAiOptions({ settings }) {
+  const [inputValue, setInputValue] = useState(settings?.GroqApiKey);
+  const [apiKey, setApiKey] = useState(settings?.GroqApiKey);
+
   return (
     <div className="flex gap-[36px] mt-1.5">
       <div className="flex flex-col w-60">
@@ -8,41 +14,98 @@ export default function GroqAiOptions({ settings }) {
         <input
           type="password"
           name="GroqApiKey"
-          className="bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
+          className="border-none bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
           placeholder="Groq API Key"
           defaultValue={settings?.GroqApiKey ? "*".repeat(20) : ""}
           required={true}
           autoComplete="off"
           spellCheck={false}
+          onChange={(e) => setInputValue(e.target.value)}
+          onBlur={() => setApiKey(inputValue)}
         />
       </div>
       {!settings?.credentialsOnly && (
-        <div className="flex flex-col w-60">
-          <label className="text-white text-sm font-semibold block mb-3">
-            Chat Model Selection
-          </label>
-          <select
-            name="GroqModelPref"
-            defaultValue={settings?.GroqModelPref || "llama3-8b-8192"}
-            required={true}
-            className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
-          >
-            {[
-              "mixtral-8x7b-32768",
-              "llama3-8b-8192",
-              "llama3-70b-8192",
-              "gemma-7b-it",
-            ].map((model) => {
-              return (
-                <option key={model} value={model}>
-                  {model}
-                </option>
-              );
-            })}
-          </select>
-        </div>
+        <GroqAIModelSelection settings={settings} apiKey={apiKey} />
       )}
     </div>
   );
 }
+
+function GroqAIModelSelection({ apiKey, settings }) {
+  const [customModels, setCustomModels] = useState([]);
+  const [loading, setLoading] = useState(true);
+
+  useEffect(() => {
+    async function findCustomModels() {
+      if (!apiKey) {
+        setCustomModels([]);
+        setLoading(true);
+        return;
+      }
+
+      try {
+        setLoading(true);
+        const { models } = await System.customModels("groq", apiKey);
+        setCustomModels(models || []);
+      } catch (error) {
+        console.error("Failed to fetch custom models:", error);
+        setCustomModels([]);
+      } finally {
+        setLoading(false);
+      }
+    }
+    findCustomModels();
+  }, [apiKey]);
+
+  if (loading) {
+    return (
+      <div className="flex flex-col w-60">
+        <label className="text-white text-sm font-semibold block mb-3">
+          Chat Model Selection
+        </label>
+        <select
+          name="GroqModelPref"
+          disabled={true}
+          className="border-none bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
+        >
+          <option disabled={true} selected={true}>
+            --loading available models--
+          </option>
+        </select>
+        <p className="text-xs leading-[18px] font-base text-white text-opacity-60 mt-2">
+          Enter a valid API key to view all available models for your account.
+        </p>
+      </div>
+    );
+  }
+
+  return (
+    <div className="flex flex-col w-60">
+      <label className="text-white text-sm font-semibold block mb-3">
+        Chat Model Selection
+      </label>
+      <select
+        name="GroqModelPref"
+        required={true}
+        className="border-none bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
+        defaultValue={settings?.GroqModelPref}
+      >
+        {customModels.length > 0 && (
+          <optgroup label="Available models">
+            {customModels.map((model) => {
+              return (
+                <option key={model.id} value={model.id}>
+                  {model.id}
+                </option>
+              );
+            })}
+          </optgroup>
+        )}
+      </select>
+      <p className="text-xs leading-[18px] font-base text-white text-opacity-60 mt-2">
+        Select the GroqAI model you want to use for your conversations.
+      </p>
+    </div>
+  );
+}
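Note that the component above only depends on each model entry exposing an `id`. A hypothetical payload from System.customModels("groq", apiKey), inferred from how the options are rendered (the exact field set is whatever the server helper returns):

// Hypothetical shape, for illustration only: GroqAIModelSelection reads model.id
// and ignores any other fields that may be present on each entry.
const exampleResponse = {
  models: [{ id: "llama-3.1-8b-instant" }, { id: "llama-3.1-70b-versatile" }],
  error: null,
};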

File: frontend/src/hooks

@@ -32,12 +32,7 @@ const PROVIDER_DEFAULT_MODELS = {
   localai: [],
   ollama: [],
   togetherai: [],
-  groq: [
-    "mixtral-8x7b-32768",
-    "llama3-8b-8192",
-    "llama3-70b-8192",
-    "gemma-7b-it",
-  ],
+  groq: [],
   native: [],
   cohere: [
     "command-r",

File: server/utils/AiProviders/groq

@@ -13,7 +13,7 @@ class GroqLLM {
       apiKey: process.env.GROQ_API_KEY,
     });
     this.model =
-      modelPreference || process.env.GROQ_MODEL_PREF || "llama3-8b-8192";
+      modelPreference || process.env.GROQ_MODEL_PREF || "llama-3.1-8b-instant";
     this.limits = {
       history: this.promptWindowLimit() * 0.15,
       system: this.promptWindowLimit() * 0.15,
@@ -42,34 +42,24 @@ class GroqLLM {
   promptWindowLimit() {
     switch (this.model) {
-      case "mixtral-8x7b-32768":
-        return 32_768;
+      case "gemma2-9b-it":
+      case "gemma-7b-it":
+      case "llama3-70b-8192":
       case "llama3-8b-8192":
         return 8192;
-      case "llama3-70b-8192":
-        return 8192;
-      case "gemma-7b-it":
-        return 8192;
+      case "llama-3.1-70b-versatile":
+      case "llama-3.1-8b-instant":
+        return 131072;
+      case "mixtral-8x7b-32768":
+        return 32768;
       default:
         return 8192;
     }
   }
 
   async isValidChatCompletionModel(modelName = "") {
-    const validModels = [
-      "mixtral-8x7b-32768",
-      "llama3-8b-8192",
-      "llama3-70b-8192",
-      "gemma-7b-it",
-    ];
-    const isPreset = validModels.some((model) => modelName === model);
-    if (isPreset) return true;
-
-    const model = await this.openai.models
-      .retrieve(modelName)
-      .then((modelObj) => modelObj)
-      .catch(() => null);
-    return !!model;
+    return !!modelName; // name just needs to exist
   }
 
   constructPrompt({
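As a quick sanity check on the limits math in the constructor above: with the new default llama-3.1-8b-instant, promptWindowLimit() returns 131072, so each 0.15 slice works out to roughly 19,660 tokens (a worked sketch, not code from the commit):

// Worked example of the constructor's limits for the new default model.
const promptWindowLimit = 131072; // llama-3.1-8b-instant
const limits = {
  history: promptWindowLimit * 0.15, // 19660.8 -> ~19,660 tokens for history
  system: promptWindowLimit * 0.15, // 19660.8 -> ~19,660 tokens for the system prompt
};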

File: server/utils/helpers

@ -1,7 +1,4 @@
const {
OpenRouterLLM,
fetchOpenRouterModels,
} = require("../AiProviders/openRouter");
const { fetchOpenRouterModels } = require("../AiProviders/openRouter");
const { perplexityModels } = require("../AiProviders/perplexity");
const { togetherAiModels } = require("../AiProviders/togetherAi");
const { ElevenLabsTTS } = require("../TextToSpeech/elevenLabs");
@ -18,6 +15,7 @@ const SUPPORT_CUSTOM_MODELS = [
"koboldcpp",
"litellm",
"elevenlabs-tts",
"groq",
];
async function getCustomModels(provider = "", apiKey = null, basePath = null) {
@ -49,6 +47,8 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) {
return await liteLLMModels(basePath, apiKey);
case "elevenlabs-tts":
return await getElevenLabsModels(apiKey);
case "groq":
return await getGroqAiModels(apiKey);
default:
return { models: [], error: "Invalid provider for custom models" };
}
@@ -167,6 +167,33 @@ async function localAIModels(basePath = null, apiKey = null) {
   return { models, error: null };
 }
 
+async function getGroqAiModels(_apiKey = null) {
+  const { OpenAI: OpenAIApi } = require("openai");
+  const apiKey =
+    _apiKey === true
+      ? process.env.GROQ_API_KEY
+      : _apiKey || process.env.GROQ_API_KEY || null;
+  const openai = new OpenAIApi({
+    baseURL: "https://api.groq.com/openai/v1",
+    apiKey,
+  });
+  const models = (
+    await openai.models
+      .list()
+      .then((results) => results.data)
+      .catch((e) => {
+        console.error(`GroqAi:listModels`, e.message);
+        return [];
+      })
+  ).filter(
+    (model) => !model.id.includes("whisper") && !model.id.includes("tool-use")
+  );
+
+  // Api Key was successful so lets save it for future uses
+  if (models.length > 0 && !!apiKey) process.env.GROQ_API_KEY = apiKey;
+  return { models, error: null };
+}
+
 async function liteLLMModels(basePath = null, apiKey = null) {
   const { OpenAI: OpenAIApi } = require("openai");
   const openai = new OpenAIApi({
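With the helper in place, other server code can resolve Groq's live model list through the generic entry point. A hedged usage sketch (the require path and the getCustomModels export are assumptions based on the file's apparent structure, not shown in this diff):

// Sketch: exercising the new "groq" branch of getCustomModels.
// Assumes customModels.js exports getCustomModels; adjust the path to the real module.
const { getCustomModels } = require("./utils/helpers/customModels");

(async () => {
  const { models, error } = await getCustomModels("groq", process.env.GROQ_API_KEY);
  if (error) throw new Error(error);
  // whisper and tool-use models have already been filtered out by getGroqAiModels.
  console.log(models.map((model) => model.id));
})();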