Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
Layoric committed Nov 25, 2024
2 parents 8d0de2e + 22edcaa commit 382c782
Show file tree
Hide file tree
Showing 29 changed files with 410 additions and 465 deletions.
8 changes: 4 additions & 4 deletions AiServer.ServiceInterface/AudioServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace AiServer.ServiceInterface;

public class AudioServices(IBackgroundJobs jobs) : Service
{
public async Task<object> Any(ConvertAudio request)
public async Task<ArtifactGenerationResponse> Any(ConvertAudio request)
{
if (Request?.Files == null || Request.Files.Length == 0)
{
Expand All @@ -27,9 +27,9 @@ public async Task<object> Any(ConvertAudio request)
};

var transformService = base.ResolveService<MediaTransformProviderServices>();
return await transformRequest.ProcessTransform(jobs, transformService, sync: true);
return await transformRequest.ProcessSyncTransformAsync(jobs, transformService);
}
public async Task<object> Any(QueueConvertAudio request)
public async Task<QueueMediaTransformResponse> Any(QueueConvertAudio request)
{
if (Request?.Files == null || Request.Files.Length == 0)
{
Expand All @@ -52,7 +52,7 @@ public async Task<object> Any(QueueConvertAudio request)
};

var transformService = base.ResolveService<MediaTransformProviderServices>();
return await transformRequest.ProcessTransform(jobs, transformService);
return await transformRequest.ProcessQueuedTransformAsync(jobs, transformService);
}

private bool IsAudioFormat(MediaOutputFormat outputformat)
Expand Down
2 changes: 0 additions & 2 deletions AiServer.ServiceInterface/ComfyApiServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ public class ComfyApiServices(AppData appData) : Service
{
public async Task<object> Any(GetComfyModels request)
{

try
{
var comfyClient = new ComfyClient(request.ApiBaseUrl!, request.ApiKey);
Expand Down Expand Up @@ -42,4 +41,3 @@ public async Task<object> Any(GetComfyModelMappings request)
};
}
}

17 changes: 17 additions & 0 deletions AiServer.ServiceInterface/DtoExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using AiServer.ServiceModel;

namespace AiServer.ServiceInterface;

public static class DtoExtensions
{
public static TextGenerationResponse ToTextGenerationResponse(this GenerationResponse response) => response.TextOutputs?.Count > 0 ? new() {
Results = response.TextOutputs,
ResponseStatus = response.ResponseStatus
} : throw new Exception("Failed to generate any text outputs");

public static ArtifactGenerationResponse ToArtifactGenerationResponse(this GenerationResponse response) => response.Outputs?.Count > 0 ? new() {
Results = response.Outputs,
ResponseStatus = response.ResponseStatus
} : throw new Exception("Failed to generate any outputs");

}
150 changes: 93 additions & 57 deletions AiServer.ServiceInterface/GenerationServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,47 +9,64 @@ namespace AiServer.ServiceInterface;

public class GenerationServices(IBackgroundJobs jobs, AppData appData) : Service
{
public async Task<object> Any(GetArtifactGenerationStatus request)
public GetTextGenerationStatusResponse Any(GetTextGenerationStatus request)
{
JobResult? job = null;
if (request.JobId != null)
{
job = jobs.GetJob((long)request.JobId);
}
var job = (request.JobId != null
? jobs.GetJob((long)request.JobId)
: null)
?? (!string.IsNullOrEmpty(request.RefId)
? jobs.GetJobByRefId(request.RefId)
: null);

if (!string.IsNullOrEmpty(request.RefId))
{
job = jobs.GetJobByRefId(request.RefId);
}

if(job == null || job.Job == null || job.Summary.RefId == null)
if (job?.Job == null || job.Summary.RefId == null)
throw HttpError.NotFound("Job not found");
if (job.Failed != null)
throw new Exception($"Job failed: {job.Failed.Error}");

// We know at this point, we definitely have a job
JobResult queuedJob = job;

var completedResponse = new GetArtifactGenerationStatusResponse
var ret = new GetTextGenerationStatusResponse
{
RefId = queuedJob.Job?.RefId ?? queuedJob.Summary.RefId,
JobId = queuedJob.Job?.Id ?? queuedJob.Summary.Id,
Status = queuedJob.Job?.Status ?? queuedJob.Job!.State.ToString(),
JobState = queuedJob.Job?.State ?? queuedJob.Summary.State
RefId = job.Job?.RefId ?? job.Summary.RefId,
JobId = job.Job?.Id ?? job.Summary.Id,
Status = job.Job?.Status ?? job.Job!.State.ToString(),
JobState = job.Job?.State ?? job.Summary.State
};

// Handle failed jobs
if (queuedJob.Failed != null)
{
throw new Exception($"Job failed: {queuedJob.Failed.Error}");
}
if ((job.Job?.State ?? job.Summary.State) != BackgroundJobState.Completed)
return ret;

if ((queuedJob.Job?.State ?? queuedJob.Summary.State) != BackgroundJobState.Completed)
return completedResponse;
var outputs = job.GetOutputs();
ret.Results = outputs.Item2; // Get TextOutputs
return ret;
}

public GetArtifactGenerationStatusResponse Any(GetArtifactGenerationStatus request)
{
var job = (request.JobId != null
? jobs.GetJob((long)request.JobId)
: null)
?? (!string.IsNullOrEmpty(request.RefId)
? jobs.GetJobByRefId(request.RefId)
: null);

if (job?.Job == null || job.Summary.RefId == null)
throw HttpError.NotFound("Job not found");
if (job.Failed != null)
throw new Exception($"Job failed: {job.Failed.Error}");

// Process successful job results
var outputs = queuedJob.GetOutputs();
completedResponse.Results = outputs.Item1;
var ret = new GetArtifactGenerationStatusResponse
{
RefId = job.Job?.RefId ?? job.Summary.RefId,
JobId = job.Job?.Id ?? job.Summary.Id,
Status = job.Job?.Status ?? job.Job!.State.ToString(),
JobState = job.Job?.State ?? job.Summary.State
};

return completedResponse;
if ((job.Job?.State ?? job.Summary.State) != BackgroundJobState.Completed)
return ret;

var outputs = job.GetOutputs();
ret.Results = outputs.Item1; // Get ArtifactOutputs
return ret;
}

public object Any(ActiveMediaModels request)
Expand All @@ -70,7 +87,7 @@ public object Any(ActiveMediaModels request)
};
}

public async Task<object> Any(TextToImage request)
public async Task<ArtifactGenerationResponse> Any(TextToImage request)
{
var diffRequest = new CreateGeneration
{
Expand All @@ -88,11 +105,11 @@ public async Task<object> Any(TextToImage request)
};

await using var diffServices = ResolveService<MediaProviderServices>();
var result = await diffRequest.ProcessGeneration(jobs, diffServices, sync: true) as GenerationResponse;
return result.ConvertTo<ArtifactGenerationResponse>();
var result = await diffRequest.ProcessSyncGenerationAsync(jobs, diffServices);
return result.ToArtifactGenerationResponse();
}

public async Task<object> Any(ImageToImage request)
public async Task<ArtifactGenerationResponse> Any(ImageToImage request)
{
var diffRequest = new CreateGeneration
{
Expand All @@ -105,16 +122,15 @@ public async Task<object> Any(ImageToImage request)
NegativePrompt = request.NegativePrompt,
Denoise = request.Denoise,
BatchSize = request.BatchSize,
ImageInput = request.Image
}
};

await using var diffServices = ResolveService<MediaProviderServices>();
var result = await diffRequest.ProcessGeneration(jobs, diffServices, sync: true) as GenerationResponse;
return result.ConvertTo<ArtifactGenerationResponse>();
var result = await diffRequest.ProcessSyncGenerationAsync(jobs, diffServices);
return result.ToArtifactGenerationResponse();
}

public async Task<object> Any(ImageUpscale request)
public async Task<ArtifactGenerationResponse> Any(ImageUpscale request)
{
var diffRequest = new CreateGeneration
{
Expand All @@ -123,16 +139,15 @@ public async Task<object> Any(ImageUpscale request)
Model = "image-upscale-2x",
Seed = request.Seed,
TaskType = AiTaskType.ImageUpscale,
ImageInput = request.Image
}
};

await using var diffServices = ResolveService<MediaProviderServices>();
var result = await diffRequest.ProcessGeneration(jobs, diffServices, sync: true) as GenerationResponse;
return result.ConvertTo<ArtifactGenerationResponse>();
var result = await diffRequest.ProcessSyncGenerationAsync(jobs, diffServices);
return result.ToArtifactGenerationResponse();
}

public async Task<object> Any(ImageWithMask request)
public async Task<ArtifactGenerationResponse> Any(ImageWithMask request)
{
var diffRequest = new CreateGeneration
{
Expand All @@ -144,38 +159,34 @@ public async Task<object> Any(ImageWithMask request)
PositivePrompt = request.PositivePrompt,
NegativePrompt = request.NegativePrompt,
Denoise = request.Denoise,
ImageInput = request.Image,
MaskInput = request.Mask
}
};

await using var diffServices = ResolveService<MediaProviderServices>();
var result = await diffRequest.ProcessGeneration(jobs, diffServices, sync: true) as GenerationResponse;
return result.ConvertTo<ArtifactGenerationResponse>();
var result = await diffRequest.ProcessSyncGenerationAsync(jobs, diffServices);
return result.ToArtifactGenerationResponse();
}

public async Task<object> Any(ImageToText request)
public async Task<TextGenerationResponse> Any(ImageToText request)
{
var diffRequest = new CreateGeneration
{
Request = new()
{
Model = "image-to-text",
TaskType = AiTaskType.ImageToText,
ImageInput = request.Image
}
};

await using var diffServices = ResolveService<MediaProviderServices>();
var result = await diffRequest.ProcessGeneration(jobs, diffServices, sync: true) as GenerationResponse;
return result.ConvertTo<TextGenerationResponse>();
var result = await diffRequest.ProcessSyncGenerationAsync(jobs, diffServices);
return result.ToTextGenerationResponse();
}
}

public static class GenerationServiceExtensions
{
public static async Task<object> ProcessGeneration(this CreateGeneration diffRequest, IBackgroundJobs jobs,
MediaProviderServices genProviderServices, bool sync = false)
public static async Task<QueueGenerationResponse> ProcessQueuedGenerationAsync(this CreateGeneration diffRequest, IBackgroundJobs jobs, MediaProviderServices genProviderServices)
{
CreateGenerationResponse? diffResponse = null;
try
Expand All @@ -189,7 +200,7 @@ public static async Task<object> ProcessGeneration(this CreateGeneration diffReq
throw;
}

if(diffResponse == null)
if (diffResponse == null)
throw new Exception("Failed to start generation");

var job = jobs.GetJob(diffResponse.Id);
Expand Down Expand Up @@ -222,12 +233,37 @@ public static async Task<object> ProcessGeneration(this CreateGeneration diffReq
throw new Exception($"Job failed: {job.Failed.Error}");
}

// If not a synchronous request, return immediately with job details
if (sync != true)
return queueResponse;
}

public static async Task<GenerationResponse> ProcessSyncGenerationAsync(this CreateGeneration diffRequest, IBackgroundJobs jobs, MediaProviderServices genProviderServices)
{
CreateGenerationResponse? diffResponse = null;
try
{
var response = genProviderServices.Any(diffRequest);
diffResponse = response as CreateGenerationResponse;
}
catch (Exception e)
{
Console.WriteLine(e);
throw;
}

if (diffResponse == null)
throw new Exception("Failed to start generation");

var job = jobs.GetJob(diffResponse.Id);
// For synchronous requests, wait for the job to be created
while (job == null)
{
return queueResponse;
await Task.Delay(1000);
job = jobs.GetJob(diffResponse.Id);
}

// We know at this point, we definitely have a job
JobResult? queuedJob = job;

var completedResponse = new GenerationResponse { };

// Wait for the job to complete max 1 minute
Expand Down
20 changes: 10 additions & 10 deletions AiServer.ServiceInterface/ImageServices.Generation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public partial class ImageServices(IBackgroundJobs jobs,
ILogger<ImageServices> log,
AppData appData) : Service
{
public async Task<object> Any(QueueTextToImage request)
public async Task<QueueGenerationResponse> Any(QueueTextToImage request)
{
if (!string.IsNullOrEmpty(request.Model) && !appData.ModelSupportsTask(request.Model, AiTaskType.TextToImage))
{
Expand Down Expand Up @@ -44,10 +44,10 @@ public async Task<object> Any(QueueTextToImage request)
};

await using var diffServices = ResolveService<MediaProviderServices>();
return await diffRequest.ProcessGeneration(jobs, diffServices);
return await diffRequest.ProcessQueuedGenerationAsync(jobs, diffServices);
}

public async Task<object> Any(QueueImageUpscale request)
public async Task<QueueGenerationResponse> Any(QueueImageUpscale request)
{
if(Request?.Files == null || Request.Files.Length == 0)
{
Expand All @@ -70,10 +70,10 @@ public async Task<object> Any(QueueImageUpscale request)
};

await using var diffServices = ResolveService<MediaProviderServices>();
return await diffRequest.ProcessGeneration(jobs, diffServices);
return await diffRequest.ProcessQueuedGenerationAsync(jobs, diffServices);
}

public async Task<object> Any(QueueImageToImage request)
public async Task<QueueGenerationResponse> Any(QueueImageToImage request)
{
if (Request?.Files == null || Request.Files.Length == 0)
{
Expand All @@ -100,10 +100,10 @@ public async Task<object> Any(QueueImageToImage request)
};

await using var diffServices = ResolveService<MediaProviderServices>();
return await diffRequest.ProcessGeneration(jobs, diffServices);
return await diffRequest.ProcessQueuedGenerationAsync(jobs, diffServices);
}

public async Task<object> Any(QueueImageWithMask request)
public async Task<QueueGenerationResponse> Any(QueueImageWithMask request)
{
if (Request?.Files == null || Request.Files.Length > 2)
{
Expand All @@ -129,10 +129,10 @@ public async Task<object> Any(QueueImageWithMask request)
};

await using var diffServices = ResolveService<MediaProviderServices>();
return await diffRequest.ProcessGeneration(jobs, diffServices);
return await diffRequest.ProcessQueuedGenerationAsync(jobs, diffServices);
}

public async Task<object> Any(QueueImageToText request)
public async Task<QueueGenerationResponse> Any(QueueImageToText request)
{
var diffRequest = new CreateGeneration
{
Expand All @@ -147,7 +147,7 @@ public async Task<object> Any(QueueImageToText request)
};

await using var genServices = ResolveService<MediaProviderServices>();
return await diffRequest.ProcessGeneration(jobs, genServices);
return await diffRequest.ProcessQueuedGenerationAsync(jobs, genServices);
}
}

Expand Down
Loading

0 comments on commit 382c782

Please sign in to comment.