diff --git a/discojs/src/client/client.ts b/discojs/src/client/client.ts index 9f05febc7..795f9346e 100644 --- a/discojs/src/client/client.ts +++ b/discojs/src/client/client.ts @@ -178,6 +178,8 @@ export abstract class Client extends EventEmitter<{'status': RoundStatus}>{ url.pathname += `tasks/${this.task.id}/model.json` const response = await fetch(url); + if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`); + const encoded = new Uint8Array(await response.arrayBuffer()) return await serialization.model.decode(encoded) } diff --git a/discojs/src/task/task_handler.ts b/discojs/src/task/task_handler.ts index 7b38a6928..d2b692bb4 100644 --- a/discojs/src/task/task_handler.ts +++ b/discojs/src/task/task_handler.ts @@ -20,7 +20,7 @@ export async function pushTask( task: Task, model: Model, ): Promise { - await fetch(urlToTasks(base), { + const response = await fetch(urlToTasks(base), { method: "POST", body: JSON.stringify({ task, @@ -28,12 +28,14 @@ export async function pushTask( weights: await serialization.weights.encode(model.weights), }), }); + if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`); } export async function fetchTasks( base: URL, ): Promise>> { const response = await fetch(urlToTasks(base)); + if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`); const tasks: unknown = await response.json(); if (!Array.isArray(tasks)) { diff --git a/webapp/src/components/testing/__tests__/Testing.spec.ts b/webapp/src/components/testing/__tests__/Testing.spec.ts index 7cd5d62dd..341cfc129 100644 --- a/webapp/src/components/testing/__tests__/Testing.spec.ts +++ b/webapp/src/components/testing/__tests__/Testing.spec.ts @@ -60,7 +60,7 @@ it("shows stored models", async () => { it("allows to download server's models", async () => { vi.stubGlobal("fetch", async (url: string | URL) => { if (url.toString() === "http://localhost:8080/tasks") - return { json: () => Promise.resolve([TASK]) }; + return new Response(JSON.stringify([TASK])); throw new Error(`unhandled get: ${url}`); }); afterEach(() => { diff --git a/webapp/src/components/training/__tests__/Trainer.spec.ts b/webapp/src/components/training/__tests__/Trainer.spec.ts index b77d27493..5dfd74762 100644 --- a/webapp/src/components/training/__tests__/Trainer.spec.ts +++ b/webapp/src/components/training/__tests__/Trainer.spec.ts @@ -10,39 +10,15 @@ import { loadCSV } from "@epfml/discojs-web"; import Trainer from "../Trainer.vue"; import TrainingInformation from "../TrainingInformation.vue"; -vi.mock("axios", async () => { - async function get(url: string) { - if (url === "http://localhost:8080/tasks/titanic/model.json") { - return { - data: await serialization.model.encode( - await defaultTasks.titanic.getModel(), - ), - }; - } - throw new Error("unhandled get"); - } - - const axios = await vi.importActual("axios"); - return { - ...axios, - default: { - ...axios.default, - get, - }, - }; -}); - async function setupForTask() { const provider = defaultTasks.titanic; vi.stubGlobal("fetch", async (url: string | URL) => { - if (url.toString() === "http://localhost:8080/tasks/titanic/model.json") - return { - arrayBuffer: async () => { - const model = await provider.getModel(); - return await serialization.model.encode(model); - }, - }; + if (url.toString() === "http://localhost:8080/tasks/titanic/model.json") { + const model = await provider.getModel(); + const encoded = await serialization.model.encode(model); + return new Response(encoded); + } throw new Error(`unhandled get: ${url}`); }); afterEach(() => {