Initial work for lms pull (#108)

* Downloading artifacts

* Fix documentation
This commit is contained in:
ryan-the-crayon
2024-11-26 18:43:07 -05:00
committed by GitHub
parent bca35e0cce
commit d23b063b0b
5 changed files with 213 additions and 37 deletions

48
src/downloadPbUpdater.ts Normal file
View File

@ -0,0 +1,48 @@
import { text } from "@lmstudio/lms-common";
import { type DownloadProgressUpdate } from "@lmstudio/sdk";
import { formatSizeBytes1000 } from "./formatSizeBytes1000.js";
import { type ProgressBar } from "./ProgressBar.js";
function formatRemainingTime(timeSeconds: number) {
const seconds = timeSeconds % 60;
const minutes = Math.floor(timeSeconds / 60) % 60;
const hours = Math.floor(timeSeconds / 3600);
if (hours > 0) {
return `${String(hours).padStart(2, "0")}:${String(minutes).padStart(2, "0")}:${String(seconds).padStart(2, "0")}`;
}
return `${String(minutes).padStart(2, "0")}:${String(seconds).padStart(2, "0")}`;
}
/**
* Given a progress bar pb, return a function that updates the progress bar with the given
* DownloadProgressUpdate.
*/
export function createDownloadPbUpdater(pb: ProgressBar) {
let longestDownloadedBytesStringLength = 6;
let longestTotalBytesStringLength = 6;
let longestSpeedBytesPerSecondStringLength = 6;
return ({ downloadedBytes, totalBytes, speedBytesPerSecond }: DownloadProgressUpdate) => {
const downloadedBytesString = formatSizeBytes1000(downloadedBytes);
if (downloadedBytesString.length > longestDownloadedBytesStringLength) {
longestDownloadedBytesStringLength = downloadedBytesString.length;
}
const totalBytesString = formatSizeBytes1000(totalBytes);
if (totalBytesString.length > longestTotalBytesStringLength) {
longestTotalBytesStringLength = totalBytesString.length;
}
const speedBytesPerSecondString = formatSizeBytes1000(speedBytesPerSecond);
if (speedBytesPerSecondString.length > longestSpeedBytesPerSecondStringLength) {
longestSpeedBytesPerSecondStringLength = speedBytesPerSecondString.length;
}
const timeLeftSeconds = Math.round((totalBytes - downloadedBytes) / speedBytesPerSecond);
pb.setRatio(
downloadedBytes / totalBytes,
text`
${downloadedBytesString.padStart(longestDownloadedBytesStringLength)} /
${totalBytesString.padStart(longestTotalBytesStringLength)} |
${speedBytesPerSecondString.padStart(longestSpeedBytesPerSecondStringLength)}/s | ETA
${formatRemainingTime(timeLeftSeconds)}
`,
);
};
}

View File

@ -31,6 +31,7 @@ const cli = subcommands({
log,
// dev,
// push,
// pull,
import: importCmd,
bootstrap,
version,
@ -38,6 +39,6 @@ const cli = subcommands({
});
run(cli, process.argv.slice(2)).catch(error => {
console.error(error?.message ?? error);
console.error(error?.stack ?? error);
process.exit(1);
});

69
src/optionalPositional.ts Normal file
View File

@ -0,0 +1,69 @@
import { type Type } from "cmd-ts";
import {
type ArgParser,
type ParseContext,
type ParsingResult,
} from "cmd-ts/dist/cjs/argparser.js";
import { type OutputOf } from "cmd-ts/dist/cjs/from.js";
import { type Descriptive, type Displayed, type ProvidesHelp } from "cmd-ts/dist/cjs/helpdoc.js";
import { type PositionalArgument } from "cmd-ts/dist/cjs/newparser/parser.js";
import * as Result from "cmd-ts/dist/cjs/Result.js";
import { type HasType } from "cmd-ts/dist/cjs/type.js";
export type OptionalPositionalsConfig<Decoder extends Type<string, any>> = HasType<Decoder> &
Partial<Displayed & Descriptive> & {
default: OutputOf<Decoder>;
};
function optionalPositionalImpl<Decoder extends Type<string, any>>(
config: OptionalPositionalsConfig<Decoder>,
): ArgParser<OutputOf<Decoder>> & ProvidesHelp {
return {
helpTopics() {
const displayName = config.displayName ?? config.type.displayName ?? "arg";
return [
{
usage: `[${displayName}]`,
category: "arguments",
defaults: [],
description: config.description ?? config.type.description ?? "",
},
];
},
register(_opts) {},
async parse({ nodes, visitedNodes }: ParseContext): Promise<ParsingResult<OutputOf<Decoder>>> {
const positionals = nodes.filter(
(node): node is PositionalArgument =>
node.type === "positionalArgument" && !visitedNodes.has(node),
);
if (positionals.length === 0) {
return Result.ok(config.default);
}
visitedNodes.add(positionals[0]);
const decoded = await Result.safeAsync(config.type.from(positionals[0].raw));
if (Result.isOk(decoded)) {
return Result.ok(decoded.value);
} else {
return Result.err({
errors: [
{
nodes: [positionals[0]],
message: decoded.error.message,
},
],
});
}
},
};
}
type OptionalPositionalsParser<Decoder extends Type<string, any>> = ArgParser<OutputOf<Decoder>> &
ProvidesHelp;
export function optionalPositional<Decoder extends Type<string, any>>(
config: OptionalPositionalsConfig<Decoder>,
): OptionalPositionalsParser<Decoder> {
return optionalPositionalImpl(config);
}

View File

@ -7,21 +7,12 @@ import { boolean, command, flag, option, optional, positional, string } from "cm
import inquirer from "inquirer";
import { askQuestion } from "../confirm.js";
import { createClient, createClientArgs } from "../createClient.js";
import { createDownloadPbUpdater } from "../downloadPbUpdater.js";
import { formatSizeBytes1000 } from "../formatSizeBytes1000.js";
import { createLogger, logLevelArgs } from "../logLevel.js";
import { ProgressBar } from "../ProgressBar.js";
import { refinedNumber } from "../types/refinedNumber.js";
function formatRemainingTime(timeSeconds: number) {
const seconds = timeSeconds % 60;
const minutes = Math.floor(timeSeconds / 60) % 60;
const hours = Math.floor(timeSeconds / 3600);
if (hours > 0) {
return `${String(hours).padStart(2, "0")}:${String(minutes).padStart(2, "0")}:${String(seconds).padStart(2, "0")}`;
}
return `${String(minutes).padStart(2, "0")}:${String(seconds).padStart(2, "0")}`;
}
export const get = command({
name: "get",
description: "Searching and downloading a model from online.",
@ -261,6 +252,7 @@ export const get = command({
let isAskingExitingBehavior = false;
let canceled = false;
const pb = new ProgressBar(0, "", 22);
const updatePb = createDownloadPbUpdater(pb);
const abortController = new AbortController();
const sigintListener = () => {
process.removeListener("SIGINT", sigintListener);
@ -285,39 +277,15 @@ export const get = command({
});
};
process.addListener("SIGINT", sigintListener);
let longestDownloadedBytesStringLength = 6;
let longestTotalBytesStringLength = 6;
let longestSpeedBytesPerSecondStringLength = 6;
try {
let alreadyExisted = true;
const defaultIdentifier = await option.download({
onProgress: ({ downloadedBytes, totalBytes, speedBytesPerSecond }) => {
onProgress: update => {
alreadyExisted = false;
if (isAskingExitingBehavior) {
return;
}
const downloadedBytesString = formatSizeBytes1000(downloadedBytes);
if (downloadedBytesString.length > longestDownloadedBytesStringLength) {
longestDownloadedBytesStringLength = downloadedBytesString.length;
}
const totalBytesString = formatSizeBytes1000(totalBytes);
if (totalBytesString.length > longestTotalBytesStringLength) {
longestTotalBytesStringLength = totalBytesString.length;
}
const speedBytesPerSecondString = formatSizeBytes1000(speedBytesPerSecond);
if (speedBytesPerSecondString.length > longestSpeedBytesPerSecondStringLength) {
longestSpeedBytesPerSecondStringLength = speedBytesPerSecondString.length;
}
const timeLeftSeconds = Math.round((totalBytes - downloadedBytes) / speedBytesPerSecond);
pb.setRatio(
downloadedBytes / totalBytes,
text`
${downloadedBytesString.padStart(longestDownloadedBytesStringLength)} /
${totalBytesString.padStart(longestTotalBytesStringLength)} |
${speedBytesPerSecondString.padStart(longestSpeedBytesPerSecondStringLength)}/s | ETA
${formatRemainingTime(timeLeftSeconds)}
`,
);
updatePb(update);
},
onStartFinalizing: () => {
alreadyExisted = false;

90
src/subcommands/pull.ts Normal file
View File

@ -0,0 +1,90 @@
import { text } from "@lmstudio/lms-common";
import { kebabCaseRegex } from "@lmstudio/lms-shared-types";
import { command, positional, string, type Type } from "cmd-ts";
import { resolve } from "path";
import { createClient, createClientArgs } from "../createClient.js";
import { createDownloadPbUpdater } from "../downloadPbUpdater.js";
import { exists } from "../exists.js";
import { createLogger, logLevelArgs } from "../logLevel.js";
import { optionalPositional } from "../optionalPositional.js";
import { ProgressBar } from "../ProgressBar.js";
const artifactIdentifierType: Type<string, { owner: string; name: string }> = {
async from(str) {
str = str.trim().toLowerCase();
const parts = str.split("/");
if (parts.length !== 2) {
throw new Error("Invalid artifact identifier. Must be in the form of 'owner/name'.");
}
const [owner, name] = parts;
if (!kebabCaseRegex.test(owner)) {
throw new Error("Invalid owner. Must be kebab-case.");
}
if (!kebabCaseRegex.test(name)) {
throw new Error("Invalid name. Must be kebab-case.");
}
return { owner, name };
},
};
export const pull = command({
name: "pull",
description: "Pull an artifact from LM Studio Hub to a local folder.",
args: {
artifactIdentifier: positional({
displayName: "artifact-identifier",
description: "The identifier for the artifact. Must be in the form of 'owner/name'.",
type: artifactIdentifierType,
}),
path: optionalPositional({
displayName: "path",
description: text`
The path to the folder to pull the resources into. If not provided, defaults to a new folder
with the artifact name in the current working directory.
`,
type: string,
default: "",
}),
...logLevelArgs,
...createClientArgs,
},
handler: async args => {
const logger = createLogger(args);
const client = await createClient(logger, args);
const { owner, name } = args.artifactIdentifier;
let path = args.path;
let autoNamed: boolean;
if (path === "") {
path = resolve(`./${name}`);
autoNamed = true;
logger.debug(`Path not provided. Using default: ${path}`);
} else {
path = resolve(path);
autoNamed = false;
logger.debug(`Using provided path: ${path}`);
}
if (await exists(path)) {
logger.error(`Path already exists: ${path}`);
if (autoNamed) {
logger.error("You can provide a different path by providing it as a second argument.");
}
process.exit(1);
}
const pb = new ProgressBar(0, "", 22);
const updatePb = createDownloadPbUpdater(pb);
await client.repository.downloadArtifact({
owner,
name,
revisionNumber: -1, // -1 means the latest revision.
path,
onProgress: update => {
updatePb(update);
},
onStartFinalizing: () => {
pb.stop();
logger.info("Finalizing download...");
},
});
logger.info(`Artifact successfully pulled to ${path}.`);
},
});