This commit is contained in:
Ryan Huang
2024-04-16 18:31:56 -04:00
parent e029b80cf6
commit ed73fd3747
6 changed files with 242 additions and 43 deletions

View File

@ -9,10 +9,13 @@
"concatenator",
"Fcfs",
"GGUF",
"gptneox",
"immer",
"inferencing",
"lmstudio",
"proxying",
"replit",
"starcoder",
"Streamable"
],
"rewrap.wrappingColumn": 100,

View File

@ -0,0 +1,128 @@
import chalk from "chalk";
/**
 * A small registry that maps keys to values, with a fallback for keys that were never registered.
 *
 * Callers look values up with a `TLookupKey`. The key mapper turns that into a `TInnerKey`, which
 * is what the backing map is indexed by. Keys handed to {@link InfoLookup.register} are stored
 * as-is; they are not passed through the key mapper.
 */
export class InfoLookup<TInnerKey, TLookupKey, TValue> {
  /** Registered values, indexed by inner key. */
  private readonly lookup = new Map<TInnerKey, TValue>();
  /** Converts a lookup key into the inner key used to index the map. */
  private readonly keyMapper: (key: TLookupKey) => TInnerKey;
  /** Produces (or throws instead of) a value for a lookup key with no registered entry. */
  private readonly fallback: (key: TLookupKey) => TValue;

  /**
   * Use {@link InfoLookup.create} or {@link InfoLookup.createWithKeyMapper} instead.
   *
   * @param keyMapper - Converts lookup keys to inner keys.
   * @param fallback - Handler for unregistered keys. When omitted, lookups of unregistered keys
   * throw an `Error`.
   */
  private constructor(
    keyMapper: (key: TLookupKey) => TInnerKey,
    fallback: ((key: TLookupKey) => TValue) | undefined,
  ) {
    this.keyMapper = keyMapper;
    if (fallback == null) {
      this.fallback = missingKey => {
        throw new Error(`Key not found: ${missingKey}`);
      };
    } else {
      this.fallback = fallback;
    }
  }

  /**
   * Creates a lookup whose lookup keys are used directly as inner keys.
   *
   * @param options - Optional `fallback` invoked for unregistered keys.
   * @returns A new, empty lookup.
   */
  public static create<TKey, TValue>(options: { fallback?: (key: TKey) => TValue } = {}) {
    const identity = (key: TKey) => key;
    return new InfoLookup<TKey, TKey, TValue>(identity, options.fallback);
  }

  /**
   * Creates a lookup that converts every lookup key with `keyMapper` before consulting the map.
   *
   * @param options - The `fallback` for unregistered keys and the `keyMapper` to apply.
   * @returns A new, empty lookup.
   */
  public static createWithKeyMapper<TInnerKey, TLookupKey, TValue>(options: {
    fallback: (key: TLookupKey) => TValue;
    keyMapper: (key: TLookupKey) => TInnerKey;
  }) {
    return new InfoLookup<TInnerKey, TLookupKey, TValue>(options.keyMapper, options.fallback);
  }

  /**
   * Registers one value under any number of inner keys.
   *
   * @param args - Zero or more inner keys followed by the value to store under each of them.
   * @returns This lookup, so calls can be chained.
   */
  public register(...args: [...TInnerKey[], TValue]): this {
    // The last argument is the value; everything before it is a key.
    const value = args[args.length - 1] as TValue;
    const keys = args.slice(0, -1) as TInnerKey[];
    for (const key of keys) {
      this.lookup.set(key, value);
    }
    return this;
  }

  /**
   * Finds the value for a lookup key.
   *
   * @param lookupKey - The key to look up; it is mapped to an inner key first.
   * @returns The registered value, or whatever the fallback returns for `lookupKey`.
   */
  public find(lookupKey: TLookupKey): TValue {
    const innerKey = this.keyMapper(lookupKey);
    // Check with `has` so that a registered `undefined` value is still returned as a hit.
    if (!this.lookup.has(innerKey)) {
      return this.fallback(lookupKey);
    }
    return this.lookup.get(innerKey) as TValue;
  }
}
// Chalk colorers, one per kind of model an architecture entry is registered for.
const llmColorer = chalk.cyanBright;
const visionColorer = chalk.yellowBright;
const embeddingColorer = chalk.blueBright;

/**
 * Registration table: the lowercase keys an architecture is registered under, its display name,
 * and the colorer used when printing it. Entries are registered in the order listed.
 */
const architectureEntries: ReadonlyArray<{
  keys: readonly string[];
  name: string;
  colorer: typeof llmColorer;
}> = [
  { keys: ["phi2", "phi-2"], name: "Phi-2", colorer: llmColorer },
  { keys: ["mistral"], name: "Mistral", colorer: llmColorer },
  { keys: ["llama"], name: "Llama", colorer: llmColorer },
  { keys: ["gptneox", "gpt-neo-x", "gpt_neo_x"], name: "GPT-NeoX", colorer: llmColorer },
  { keys: ["mpt"], name: "MPT", colorer: llmColorer },
  { keys: ["replit"], name: "Replit", colorer: llmColorer },
  { keys: ["starcoder"], name: "StarCoder", colorer: llmColorer },
  { keys: ["falcon"], name: "Falcon", colorer: llmColorer },
  { keys: ["qwen"], name: "Qwen", colorer: llmColorer },
  { keys: ["qwen2"], name: "Qwen2", colorer: llmColorer },
  { keys: ["stablelm"], name: "StableLM", colorer: llmColorer },
  { keys: ["mamba"], name: "mamba", colorer: llmColorer },
  { keys: ["command-r"], name: "Command R", colorer: llmColorer },
  { keys: ["gemma"], name: "Gemma", colorer: llmColorer },
  { keys: ["bert"], name: "BERT", colorer: embeddingColorer },
  { keys: ["nomic-bert"], name: "Nomic BERT", colorer: embeddingColorer },
  { keys: ["clip"], name: "CLIP", colorer: visionColorer },
];

/**
 * Maps an architecture identifier to its display name and colorer.
 *
 * Lookups are case-insensitive because the lookup key is lowercased before the map is consulted.
 * An identifier with no registered entry falls back to the identifier itself as the name, colored
 * with the LLM colorer.
 */
export const architectureInfoLookup = InfoLookup.createWithKeyMapper({
  keyMapper: (arch: string) => arch.toLowerCase(),
  fallback: (arch: string) => ({ name: arch, colorer: llmColorer }),
});

for (const { keys, name, colorer } of architectureEntries) {
  // All aliases of one architecture share the same info object.
  architectureInfoLookup.register(...keys, { name, colorer });
}

View File

@ -1,8 +1,20 @@
import { type SimpleLogger } from "@lmstudio/lms-common";
import { LMStudioClient } from "@lmstudio/sdk";
import { getServerLastStatus } from "./subcommands/server";
export function createClient(logger: SimpleLogger) {
/**
 * Creates an `LMStudioClient` that connects to the local server over WebSocket.
 *
 * The port is taken from the last recorded server status. If that status cannot be obtained
 * (`getServerLastStatus` rejects), the failure is logged at debug level and port 1234 is used.
 *
 * @param logger - Logger used for debug output and handed to the created client.
 * @returns A promise resolving to a client whose `baseUrl` is `ws://127.0.0.1:<port>`.
 */
export async function createClient(logger: SimpleLogger): Promise<LMStudioClient> {
  let port: number;
  try {
    const lastStatus = await getServerLastStatus(logger);
    port = lastStatus.port;
  } catch (e) {
    // Best effort: a missing or unreadable status must not prevent the client from being created.
    logger.debug("Failed to get last server status", e);
    port = 1234;
  }
  const baseUrl = `ws://127.0.0.1:${port}`;
  // Log the full URL, not just the port, so the message matches what the client actually uses.
  logger.debug(`Connecting to server with baseUrl ${baseUrl}`);
  return new LMStudioClient({
    baseUrl,
    logger,
  });
}

View File

@ -1,5 +1,5 @@
import { run, subcommands } from "cmd-ts";
import { list } from "./subcommands/list";
import { ls } from "./subcommands/list";
import { start, status, stop } from "./subcommands/server";
import { printVersion, version } from "./subcommands/version";
@ -16,7 +16,7 @@ const cli = subcommands({
start,
status,
stop,
list,
ls,
},
});

View File

@ -1,7 +1,8 @@
import { type DownloadedModel } from "@lmstudio/sdk";
import chalk from "chalk";
import { command, subcommands } from "cmd-ts";
import { command, flag } from "cmd-ts";
import columnify from "columnify";
import { architectureInfoLookup } from "../architectureStylizations";
import { createClient } from "../createClient";
import { formatSizeBytesWithColor1000 } from "../formatSizeBytes1000";
import { createLogger, logLevelArgs } from "../logLevel";
@ -16,8 +17,12 @@ function loadedCheck(count: number) {
}
}
function coloredArch(arch?: string) {
return arch ?? "";
/**
 * Formats a model's architecture for display in the table.
 *
 * @param architecture - Architecture identifier of the model, if any.
 * @returns An empty string when no architecture is given (or it is empty); otherwise the
 * architecture's display name, padded with one space on each side and passed through the
 * colorer found for it in `architectureInfoLookup`.
 */
function architecture(architecture?: string) {
  if (architecture === undefined || architecture === "") {
    return "";
  }
  const info = architectureInfoLookup.find(architecture);
  const paddedName = " " + info.name + " ";
  return info.colorer(paddedName);
}
function printDownloadedModelsTable(
@ -60,20 +65,22 @@ function printDownloadedModelsTable(
// Group is a model itself
const model = models[0];
return {
address: chalk.whiteBright(group),
address: chalk.grey(" ") + chalk.cyanBright(group),
sizeBytes: formatSizeBytesWithColor1000(model.sizeBytes),
arch: coloredArch(model.architecture),
arch: architecture(model.architecture),
loaded: loadedCheck(model.loadedIdentifiers.length),
};
}
return [
// Empty line between groups
{},
// Group title
{ address: chalk.whiteBright(group), sizeBytes: "", arch: "", loaded: "" },
// Models
{ address: chalk.grey(" ") + chalk.cyanBright(group) },
// Models within the group
...models.map(model => ({
address: chalk.black(". ") + chalk.gray("/" + model.remaining),
address: chalk.grey(" ") + chalk.white("/" + model.remaining),
sizeBytes: formatSizeBytesWithColor1000(model.sizeBytes),
arch: coloredArch(model.architecture),
arch: architecture(model.architecture),
loaded: loadedCheck(model.loadedIdentifiers.length),
})),
];
@ -87,18 +94,18 @@ function printDownloadedModelsTable(
columns: ["address", "sizeBytes", "arch", "loaded"],
config: {
address: {
headingTransform: () => chalk.cyanBright("ADDRESS"),
headingTransform: () => chalk.gray(" ") + chalk.greenBright("ADDRESS"),
},
sizeBytes: {
headingTransform: () => chalk.cyanBright("SIZE"),
headingTransform: () => chalk.greenBright("SIZE"),
align: "right",
},
arch: {
headingTransform: () => chalk.cyanBright("ARCHITECTURE"),
align: "left",
headingTransform: () => chalk.greenBright("ARCHITECTURE"),
align: "center",
},
loaded: {
headingTransform: () => chalk.cyanBright("LOADED"),
headingTransform: () => chalk.greenBright("LOADED"),
align: "left",
},
},
@ -108,40 +115,89 @@ function printDownloadedModelsTable(
);
}
const downloaded = command({
name: "downloaded",
description: "List downloaded models",
/**
 * The `ls` subcommand: lists downloaded models.
 *
 * `--llm` and `--embedding` restrict the listing to those model types (both may be given);
 * `--json` prints the resulting list as JSON instead of the colored tables.
 */
export const ls = command({
name: "ls",
description: "List all downloaded models",
args: {
...logLevelArgs,
llm: flag({
long: "llm",
description: "Show only LLM models",
}),
embedding: flag({
long: "embedding",
description: "Show only embedding models",
}),
json: flag({
long: "json",
description: "Outputs in JSON format to stdout",
}),
},
handler: async args => {
const logger = createLogger(args);
const client = createClient(logger);
const client = await createClient(logger);
const downloadedModels = await client.system.listDownloadedModels();
const { llm, embedding, json } = args;
let downloadedModels = await client.system.listDownloadedModels();
const loadedModels = await client.llm.listLoaded();
// Remember the unfiltered count so an empty result can be explained accurately below.
const originalModelsCount = downloadedModels.length;
// Filter only when at least one type flag is set; with no flag, every model is kept.
if (llm || embedding) {
const allowedTypes = new Set<string>();
if (llm) {
allowedTypes.add("llm");
}
if (embedding) {
allowedTypes.add("embedding");
}
downloadedModels = downloadedModels.filter(model => allowedTypes.has(model.type));
}
const afterFilteringModelsCount = downloadedModels.length;
// JSON mode prints the (possibly filtered) list and returns before any table or
// empty-list message is printed.
if (json) {
console.info(JSON.stringify(downloadedModels));
return;
}
// Nothing to show: distinguish "nothing downloaded at all" from "everything filtered out".
if (afterFilteringModelsCount === 0) {
if (originalModelsCount === 0) {
console.info(chalk.redBright("You have not downloaded any models yet."));
} else {
console.info(
chalk.redBright(
`You have ${originalModelsCount} models, but none of them match the filter.`,
),
);
}
return;
}
console.info();
console.info();
// One table per model type; a type with no models is omitted entirely.
const llms = downloadedModels.filter(model => model.type === "llm");
if (llms.length > 0) {
printDownloadedModelsTable(
chalk.bgGreenBright.black(" LLM ") + " " + chalk.green("(Large Language Models)"),
downloadedModels.filter(model => model.type === "llm"),
llms,
loadedModels,
);
console.info();
console.info();
}
const embeddings = downloadedModels.filter(model => model.type === "embedding");
if (embeddings.length > 0) {
printDownloadedModelsTable(
chalk.bgGreenBright.black(" Embeddings "),
downloadedModels.filter(model => model.type === "embedding"),
embeddings,
loadedModels,
);
console.info();
console.info();
}
},
});
export const list = subcommands({
name: "list",
description: "List models",
cmds: { downloaded },
});

View File

@ -156,7 +156,7 @@ async function checkHttpServerWithRetries(logger: SimpleLogger, port: number, ma
/**
* Gets the last status of the server.
*/
async function getServerLastStatus(logger: SimpleLogger) {
export async function getServerLastStatus(logger: SimpleLogger) {
const lastStatusPath = getServerLastStatusPath();
logger.debug(`Reading last status from ${lastStatusPath}`);
const lastStatus = JSON.parse(await readFile(lastStatusPath, "utf-8")) as HttpServerLastStatus;