Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a NextJS frontend for training GPT on wikitext and executing some web text dataset related tests #620

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions browser/client/.eslintrc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"extends": "next/core-web-vitals"
}
36 changes: 36 additions & 0 deletions browser/client/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.

# dependencies
/node_modules
/.pnp
.pnp.js
.yarn/install-state.gz

# testing
/coverage

# next.js
/.next/
/out/

# production
/build

# misc
.DS_Store
*.pem

# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*

# local env files
.env*.local

# vercel
.vercel

# typescript
*.tsbuildinfo
next-env.d.ts
36 changes: 36 additions & 0 deletions browser/client/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next-app`](https://github.com/vercel/next.js/tree/canary/packages/create-next-app).

## Getting Started

First, run the development server:

```bash
npm run dev
# or
yarn dev
# or
pnpm dev
# or
bun dev
```

Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.

You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file.

This project uses [`next/font`](https://nextjs.org/docs/basic-features/font-optimization) to automatically optimize and load Inter, a custom Google Font.

## Learn More

To learn more about Next.js, take a look at the following resources:

- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.

You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js/) - your feedback and contributions are welcome!

## Deploy on Vercel

The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.

Check out our [Next.js deployment documentation](https://nextjs.org/docs/deployment) for more details.
38 changes: 38 additions & 0 deletions browser/client/app/api/files/route.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import fs, { Stats } from 'fs'
import { NextRequest, NextResponse } from 'next/server'
import path from 'path'
import { ReadableOptions } from 'stream'

function streamFile(path: string, options?: ReadableOptions): ReadableStream<Uint8Array> {
const stream = fs.createReadStream(path, options)

return new ReadableStream({
start(controller) {
stream.on('data', (chunk: Buffer) => controller.enqueue(new Uint8Array(chunk)))
stream.on('end', () => controller.close())
stream.on('error', (error: NodeJS.ErrnoException) => controller.error(error))
},
cancel() {
stream.destroy()
},
})
}

export async function GET(req: NextRequest): Promise<NextResponse> {
const type = req.nextUrl.searchParams.get('type') as keyof typeof files
if (!type) return new NextResponse('Missing type parameter', { status: 400 })

const file = files[type][0]
const stats: Stats = await fs.promises.stat(file)
const data: ReadableStream<Uint8Array> = streamFile(file) // Stream the file with a 1kb chunk
const res = new NextResponse(data, {
status: 200,
headers: new Headers({
'content-disposition': `attachment; filename=${path.basename(file)}`,
'content-type': 'application/octet-stream',
'content-length': stats.size + '',
}),
})

return res
}
Binary file added browser/client/app/favicon.ico
Binary file not shown.
81 changes: 81 additions & 0 deletions browser/client/app/globals.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
@tailwind base;
@tailwind components;
@tailwind utilities;

@layer base {
:root {
--background: 0 0% 100%;
--foreground: 222.2 84% 4.9%;

--card: 0 0% 100%;
--card-foreground: 222.2 84% 4.9%;

--popover: 0 0% 100%;
--popover-foreground: 222.2 84% 4.9%;

--primary: 222.2 47.4% 11.2%;
--primary-foreground: 210 40% 98%;

--secondary: 210 40% 96.1%;
--secondary-foreground: 222.2 47.4% 11.2%;

--muted: 210 40% 96.1%;
--muted-foreground: 215.4 16.3% 46.9%;

--accent: 210 40% 96.1%;
--accent-foreground: 222.2 47.4% 11.2%;

--destructive: 0 84.2% 60.2%;
--destructive-foreground: 210 40% 98%;

--border: 214.3 31.8% 91.4%;
--input: 214.3 31.8% 91.4%;
--ring: 222.2 84% 4.9%;

--radius: 0.5rem;
}

.dark {
--background: 222.2 84% 4.9%;
--foreground: 210 40% 98%;

--card: 222.2 84% 4.9%;
--card-foreground: 210 40% 98%;

--popover: 222.2 84% 4.9%;
--popover-foreground: 210 40% 98%;

--primary: 210 40% 98%;
--primary-foreground: 222.2 47.4% 11.2%;

--secondary: 217.2 32.6% 17.5%;
--secondary-foreground: 210 40% 98%;

--muted: 217.2 32.6% 17.5%;
--muted-foreground: 215 20.2% 65.1%;

--accent: 217.2 32.6% 17.5%;
--accent-foreground: 210 40% 98%;

--destructive: 0 62.8% 30.6%;
--destructive-foreground: 210 40% 98%;

--border: 217.2 50.6% 25.5%;
--input: 217.2 32.6% 17.5%;
--ring: 212.7 26.8% 83.9%;
}
}

@layer base {
* {
@apply border-border;
}
body {
@apply bg-background text-foreground;
}
}

body {
color: hsl(var(--foreground));
background: hsl(var(--background));
}
22 changes: 22 additions & 0 deletions browser/client/app/layout.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import type { Metadata } from 'next'
import { Inter } from 'next/font/google'
import './globals.css'

const inter = Inter({ subsets: ['latin'] })

export const metadata: Metadata = {
title: 'Create Next App',
description: 'Generated by create next app',
}

export default function RootLayout({
children,
}: {
children: React.ReactNode
}) {
return (
<html lang="en">
<body className={`${inter.className} dark`}>{children}</body>
</html>
)
}
166 changes: 166 additions & 0 deletions browser/client/app/page.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
'use client'
import path from 'path'
import { useEffect, useState } from 'react'
import { tf, dataset, defaultTasks, Task, browser } from '@epfml/discojs-web'
import { main } from '@/disco/main'
import { models } from '@epfml/discojs-web'
import { Tabs, TabsList, TabsTrigger } from '@/components/ui/tabs'
import { Separator } from '@/components/ui/separator'

const { gpt } = models

const task = defaultTasks.wikitext.getTask()
const config = task.trainingInformation.modelConfig

// TODO: source: TextSource should be loaded using some generic script using the task definition
// a script to search for dataset files corresponding to a task is defined under experiment/data.ts (getDatasetSource)
// this script should be used here as well (see TODO comment in that file)
const getSource = (datasetName: string): dataset.TextSource => {
const datasetsFolder = path.join(
'../../experiment',
'datasets',
datasetName
)
const source: dataset.TextSource = {
train: [path.join(datasetsFolder, 'train.tokens')],
validation: undefined,
}
// wikitext has a validation file but tiny-shakespeare doesnt
// this is one of the reason why the todo above is relevant
if (datasetName === 'wikitext-103') {
source.validation = [path.join(datasetsFolder, 'validation.tokens')]
}

return source
}

const getDatasplit = async (task: Task, datasetName: string) => {
const source = getSource(datasetName)
return await new browser.dataset.loader.WebTextLoader(task).loadAll(
source,
config
)
}

const runDatasetBenchmark = async (datasplit: dataset.DataSplit) => {
const ds = datasplit.train.dataset as dataset.TokenizedDataset
const iter = await ds.iterator()
const iterations = 1000
const label = `Benchmark ${iterations} iterations`
const { blockSize, batchSize, vocabSize } = config
console.log(label, 'starts', { blockSize, batchSize, vocabSize })
console.time(label)
for (let i = 0; i < iterations; i++) {
await iter.next()
}
console.timeEnd(label)
}

const DATASET_NAMES = ['wikitext-103', 'tiny-shakespeare'] as const
type DatasetName = (typeof DATASET_NAMES)[number]

export default function Home() {
const [datasetName, setDatasetName] = useState<DatasetName>('wikitext-103')
const [config, setConfig] = useState<models.GPTConfigWithWandb>()
const [availableBackends, setAvailableBackends] = useState<string[]>([])
const [backendName, setBackendName] = useState<string>('cpu')

useEffect(() => {
setConfig(gpt.getConfig(task.trainingInformation.modelConfig))
setAvailableBackends(tf.engine().backendNames())
setBackend(tf.getBackend())
}, [])

const datasetBenchmark = async () => {
const datasplit = await getDatasplit(task, datasetName)
await runDatasetBenchmark(datasplit)
}

const startTraining = async () => {
// FIXME: url in .env (or fetched from backend?)
const datasplit = await getDatasplit(task as Task, datasetName)
const url = new URL('', 'http://localhost:8000')
await main(url, task, datasplit).catch(console.error)
}

// util function to properly set the backend
// TODO: Move this to core as well?
const setBackend = (backendName: string) => async () => {
await tf.setBackend(backendName)
await tf.ready()

const tfBackend = tf.getBackend()
if (tfBackend !== backendName) {
throw new Error('backend not properly set, got: ' + tfBackend)
}

console.log('Backend set to:', tfBackend)
setBackendName(tfBackend)
}

console.log(backendName, datasetName)

return (
<main className="flex p-24 gap-8">
<pre className="bg-slate-800 rounded p-4 max-w-min">
{JSON.stringify(config, undefined, 4)}
</pre>
<div className="flex flex-col gap-8">
<div className="flex justify-between items-center gap-4 bg-slate-800 rounded py-2 px-8 h-fit">
Backend:
<Tabs value={backendName} onValueChange={setBackend}>
<TabsList className="gap-4">
{availableBackends.map((backendName, i) => (
<TabsTrigger
className="hover:!bg-slate-900"
value={backendName}
key={`btn-${i}`}
>
{backendName}
</TabsTrigger>
))}
</TabsList>
</Tabs>
</div>
<div className="flex justify-between items-center gap-4 bg-slate-800 rounded py-2 px-8 h-fit">
Dataset:
<Tabs
value={datasetName}
onValueChange={(v) => setDatasetName(v as DatasetName)}
>
<TabsList className="gap-2">
<TabsTrigger
value="wikitext-103"
className="hover:!bg-slate-900"
>
wikitext-103
</TabsTrigger>
<TabsTrigger
value="tiny-shakespeare"
className="hover:!bg-slate-900"
>
tiny-shakespeare
</TabsTrigger>
</TabsList>
</Tabs>
<Separator orientation="vertical" />
<button
onClick={datasetBenchmark}
className="bg-background rounded px-3 py-1.5 hover:bg-slate-900 text-sm font-medium"
>
benchmark
</button>
</div>
<div className="flex justify-between items-center gap-4 bg-slate-800 rounded py-2 px-8 h-fit">
Training:
<button
onClick={startTraining}
className="bg-background rounded px-3 py-1.5 hover:bg-slate-900 text-sm font-medium"
>
run
</button>
</div>
</div>
</main>
)
}
Binary file added browser/client/bun.lockb
Binary file not shown.
1 change: 1 addition & 0 deletions browser/client/chrome-webgpu.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
google-chrome --enable-unsafe-webgpu --enable-features=Vulkan,UseSkiaRenderer &
Loading
Loading