Skip to content

feat(hub): adding downloadFileToCacheDirWithProgress function #1334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import { expect, test, describe, beforeAll, afterAll } from "vitest";
import { mkdtemp, rm } from "node:fs/promises";
import { join } from "node:path";
import { tmpdir } from "node:os";
import type { DownloadFileEvent} from "./download-file-to-cache-dir-with-progress";
import { downloadFileToCacheDirWithProgress } from "./download-file-to-cache-dir-with-progress";

describe('downloadFileToCacheDirWithProgress', () => {
let tempDir: string;
beforeAll(async () => {
tempDir = await mkdtemp(join(tmpdir(), 'model-'));
});

afterAll(() => {
return rm(tempDir, { recursive: true });
});

test('file should be downloaded with progress', async () => {
const iterator = downloadFileToCacheDirWithProgress({
repo: "ggml-org/models",
path: "bert-bge-small/ggml-model-f16-big-endian.gguf",
revision: "121397626a3ba7de07c154b4bbac3ac83f5628e0",
cacheDir: tempDir,
});

let res: IteratorResult<DownloadFileEvent, string>;

do {
res = await iterator.next();
if (!res.done) {
const { path } = res.value;
expect(path).toBe('bert-bge-small/ggml-model-f16-big-endian.gguf');
}
} while (!res.done);

expect(res.value).toStrictEqual(
join(tempDir, 'models--ggml-org--models', 'snapshots', '121397626a3ba7de07c154b4bbac3ac83f5628e0', 'bert-bge-small', 'ggml-model-f16-big-endian.gguf'),
);
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import type {
DownloadFileToCacheDirParams} from "./download-file-to-cache-dir";
import {
prepareDownloadFileToCacheDir,
} from "./download-file-to-cache-dir";
import { downloadFile } from "./download-file";
import { rename } from "node:fs/promises";
import { createWriteStream } from "node:fs";
import { createSymlink } from "../utils/symlink";

export interface DownloadFileEvent {
event: 'file';
path: string;
progress: number;
}

export async function * downloadFileToCacheDirWithProgress(
params: DownloadFileToCacheDirParams
): AsyncGenerator<DownloadFileEvent, string> {
const { exists, pointerPath, blobPath } = await prepareDownloadFileToCacheDir(params);
if(exists) return pointerPath;

/**
* download with progress
*/
const incomplete = `${blobPath}.incomplete`;
console.debug(`Downloading ${params.path} to ${incomplete}`);

const response = await downloadFile(params);
if (!response || !response.ok || !response.body) {
throw new Error(`Invalid response for file ${params.path}`);
}

const contentLength = response.headers.get("Content-Length");
const totalSize = contentLength ? parseInt(contentLength, 10) : undefined;
const reader = response.body.getReader();

// Open a writable stream to the target file.
const fileStream = createWriteStream(incomplete);

let receivedSize = 0;
while (true) {
const { done, value } = await reader.read();
if (done) break;
// Write the chunk immediately to the file.
fileStream.write(value);
receivedSize += value.length;
yield {
event: 'file',
path: params.path,
progress: totalSize ? receivedSize / totalSize : 0,
};
}

// Close the writable stream.
fileStream.end();

await rename(incomplete, blobPath);
await createSymlink({ sourcePath: blobPath, finalPath: pointerPath });

return pointerPath;
}
174 changes: 104 additions & 70 deletions packages/hub/src/lib/download-file-to-cache-dir.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,38 @@
import { getHFHubCachePath, getRepoFolderName } from "./cache-management";
import { dirname, join } from "node:path";
import { writeFile, rename, lstat, mkdir, stat } from "node:fs/promises";
import type { CommitInfo, PathInfo } from "./paths-info";
import { pathsInfo } from "./paths-info";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { toRepoId } from "../utils/toRepoId";
import { downloadFile } from "./download-file";
import { createSymlink } from "../utils/symlink";

export type DownloadFileToCacheDirParams = {
repo: RepoDesignation;
path: string;
/**
* If true, will download the raw git file.
*
* For example, when calling on a file stored with Git LFS, the pointer file will be downloaded instead.
*/
raw?: boolean;
/**
* An optional Git revision id which can be a branch name, a tag, or a commit hash.
*
* @default "main"
*/
revision?: string;
hubUrl?: string;
cacheDir?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
} & Partial<CredentialsParams>

export const REGEX_COMMIT_HASH: RegExp = new RegExp("^[0-9a-f]{40}$");

function getFilePointer(storageFolder: string, revision: string, relativeFilename: string): string {
export function getFilePointer(storageFolder: string, revision: string, relativeFilename: string): string {
const snapshotPath = join(storageFolder, "snapshots");
return join(snapshotPath, revision, relativeFilename);
}
Expand All @@ -20,7 +42,7 @@ function getFilePointer(storageFolder: string, revision: string, relativeFilenam
* @param path
* @param followSymlinks
*/
async function exists(path: string, followSymlinks?: boolean): Promise<boolean> {
export async function exists(path: string, followSymlinks?: boolean): Promise<boolean> {
try {
if (followSymlinks) {
await stat(path);
Expand All @@ -33,35 +55,56 @@ async function exists(path: string, followSymlinks?: boolean): Promise<boolean>
}
}

/**
* Download a given file if it's not already present in the local cache.
* @param params
* @return the symlink to the blob object
*/
export async function downloadFileToCacheDir(
params: {
repo: RepoDesignation;
path: string;
/**
* If true, will download the raw git file.
*
* For example, when calling on a file stored with Git LFS, the pointer file will be downloaded instead.
*/
raw?: boolean;
/**
* An optional Git revision id which can be a branch name, a tag, or a commit hash.
*
* @default "main"
*/
revision?: string;
hubUrl?: string;
cacheDir?: string;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
fetch?: typeof fetch;
} & Partial<CredentialsParams>
): Promise<string> {
export async function preparePaths(params: DownloadFileToCacheDirParams, storageFolder: string): Promise<{ pointerPath: string; blobPath: string; etag: string }> {
const pathsInformation = await pathsInfo({
...params,
paths: [params.path],
revision: params.revision ?? "main",
expand: true,
});

if (!pathsInformation || pathsInformation.length !== 1) {
throw new Error(`cannot get path info for ${params.path}`);
}

const pathInfo = pathsInformation[0];
const etag = pathInfo.lfs ? pathInfo.lfs.oid : pathInfo.oid;
const pointerPath = getFilePointer(storageFolder, pathInfo.lastCommit.id, params.path);
const blobPath = join(storageFolder, "blobs", etag);

return { pointerPath, blobPath, etag };
}

export async function ensureDirectories(blobPath: string, pointerPath: string): Promise<void> {
await mkdir(dirname(blobPath), { recursive: true });
await mkdir(dirname(pointerPath), { recursive: true });
}

export async function downloadAndStoreFile(params: DownloadFileToCacheDirParams, blobPath: string): Promise<void> {
const incomplete = `${blobPath}.incomplete`;
console.debug(`Downloading ${params.path} to ${incomplete}`);

const response = await downloadFile(params);
if (!response || !response.ok || !response.body) {
throw new Error(`Invalid response for file ${params.path}`);
}

// @ts-expect-error resp.body is a Stream, but Stream in internal to node
await writeFile(incomplete, response.body);
await rename(incomplete, blobPath);
}

export type PrepareDownloadFileToCacheDirResult = {
exists: true,
pointerPath: string
blobPath?: undefined,
} | {
exists: false,
pointerPath: string;
blobPath: string;
}

export async function prepareDownloadFileToCacheDir(params: DownloadFileToCacheDirParams): Promise<PrepareDownloadFileToCacheDirResult> {
// get revision provided or default to main
const revision = params.revision ?? "main";
const cacheDir = params.cacheDir ?? getHFHubCachePath();
Expand All @@ -70,64 +113,55 @@ export async function downloadFileToCacheDir(
// get storage folder
const storageFolder = join(cacheDir, getRepoFolderName(repoId));

let commitHash: string | undefined;

// if user provides a commitHash as revision, and they already have the file on disk, shortcut everything.
if (REGEX_COMMIT_HASH.test(revision)) {
commitHash = revision;
const pointerPath = getFilePointer(storageFolder, revision, params.path);
if (await exists(pointerPath, true)) return pointerPath;
}

const pathsInformation: (PathInfo & { lastCommit: CommitInfo })[] = await pathsInfo({
...params,
paths: [params.path],
revision: revision,
expand: true,
});
if (!pathsInformation || pathsInformation.length !== 1) throw new Error(`cannot get path info for ${params.path}`);

let etag: string;
if (pathsInformation[0].lfs) {
etag = pathsInformation[0].lfs.oid; // get the LFS pointed file oid
} else {
etag = pathsInformation[0].oid; // get the repo file if not a LFS pointer
if (await exists(pointerPath, true)) return {
exists: true,
pointerPath: pointerPath,
};
}

const pointerPath = getFilePointer(storageFolder, commitHash ?? pathsInformation[0].lastCommit.id, params.path);
const blobPath = join(storageFolder, "blobs", etag);
const { pointerPath, blobPath } = await preparePaths(params, storageFolder);

// if we have the pointer file, we can shortcut the download
if (await exists(pointerPath, true)) return pointerPath;
if (await exists(pointerPath, true)) return {
exists: true,
pointerPath: pointerPath,
};

// mkdir blob and pointer path parent directory
await mkdir(dirname(blobPath), { recursive: true });
await mkdir(dirname(pointerPath), { recursive: true });
await ensureDirectories(blobPath, pointerPath);

// We might already have the blob but not the pointer
// shortcut the download if needed
if (await exists(blobPath)) {
// create symlinks in snapshot folder to blob object
await createSymlink({ sourcePath: blobPath, finalPath: pointerPath });
return pointerPath;
return { exists: true, pointerPath, }
}

const incomplete = `${blobPath}.incomplete`;
console.debug(`Downloading ${params.path} to ${incomplete}`);

const response: Response | null = await downloadFile({
...params,
revision: commitHash,
});
return {
exists: false,
pointerPath: pointerPath,
blobPath: blobPath,
}
}

if (!response || !response.ok || !response.body) throw new Error(`invalid response for file ${params.path}`);
/**
* Download a given file if it's not already present in the local cache.
* @param params
* @return the symlink to the blob object
*/
export async function downloadFileToCacheDir(
params: DownloadFileToCacheDirParams
): Promise<string> {
const { exists, pointerPath, blobPath } = await prepareDownloadFileToCacheDir(params);
if(exists) return pointerPath;

// @ts-expect-error resp.body is a Stream, but Stream in internal to node
await writeFile(incomplete, response.body);
// download the file if we don't have it
await downloadAndStoreFile(params, blobPath);

// rename .incomplete file to expect blob
await rename(incomplete, blobPath);
// create symlinks in snapshot folder to blob object
await createSymlink({ sourcePath: blobPath, finalPath: pointerPath });
return pointerPath;
}