Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 2c0eeae

Browse files
authored
Merge pull request #123 from saddam213/SDXL_LCM
Support LCM-SDXL guidance embeddings
2 parents a4fb47c + e2acc3e commit 2c0eeae

File tree

1 file changed

+114
-0
lines changed

1 file changed

+114
-0
lines changed

OnnxStack.StableDiffusion/Diffusers/LatentConsistencyXL/LatentConsistencyXLDiffuser.cs

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core;
24
using OnnxStack.Core.Model;
35
using OnnxStack.StableDiffusion.Common;
46
using OnnxStack.StableDiffusion.Config;
57
using OnnxStack.StableDiffusion.Diffusers.StableDiffusionXL;
68
using OnnxStack.StableDiffusion.Enums;
79
using OnnxStack.StableDiffusion.Models;
810
using OnnxStack.StableDiffusion.Schedulers.LatentConsistency;
11+
using System.Diagnostics;
12+
using System.Linq;
13+
using System.Threading.Tasks;
14+
using System.Threading;
15+
using System;
916

1017
namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistencyXL
1118
{
@@ -29,6 +36,92 @@ protected LatentConsistencyXLDiffuser(UNetConditionModel unet, AutoEncoderModel
2936
public override DiffuserPipelineType PipelineType => DiffuserPipelineType.LatentConsistencyXL;
3037

3138

39+
/// <summary>
40+
/// Runs the scheduler steps.
41+
/// </summary>
42+
/// <param name="modelOptions">The model options.</param>
43+
/// <param name="promptOptions">The prompt options.</param>
44+
/// <param name="schedulerOptions">The scheduler options.</param>
45+
/// <param name="promptEmbeddings">The prompt embeddings.</param>
46+
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
47+
/// <param name="progressCallback">The progress callback.</param>
48+
/// <param name="cancellationToken">The cancellation token.</param>
49+
/// <returns></returns>
50+
public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
51+
{
52+
// Get Scheduler
53+
using (var scheduler = GetScheduler(schedulerOptions))
54+
{
55+
// Get timesteps
56+
var timesteps = GetTimesteps(schedulerOptions, scheduler);
57+
58+
// Create latent sample
59+
var latents = await PrepareLatentsAsync(promptOptions, schedulerOptions, scheduler, timesteps);
60+
61+
// Get Model metadata
62+
var metadata = await _unet.GetMetadataAsync();
63+
64+
// Get Time ids
65+
var addTimeIds = GetAddTimeIds(schedulerOptions);
66+
67+
// Get Guidance Scale Embedding
68+
var guidanceEmbeddings = GetGuidanceScaleEmbedding(schedulerOptions.GuidanceScale);
69+
70+
// Loop though the timesteps
71+
var step = 0;
72+
foreach (var timestep in timesteps)
73+
{
74+
step++;
75+
var stepTime = Stopwatch.GetTimestamp();
76+
cancellationToken.ThrowIfCancellationRequested();
77+
78+
// Create input tensor.
79+
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
80+
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
81+
var timestepTensor = CreateTimestepTensor(timestep);
82+
var timeids = performGuidance ? addTimeIds.Repeat(2) : addTimeIds;
83+
84+
var outputChannels = performGuidance ? 2 : 1;
85+
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
86+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
87+
{
88+
inferenceParameters.AddInputTensor(inputTensor);
89+
inferenceParameters.AddInputTensor(timestepTensor);
90+
inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
91+
if (inferenceParameters.InputCount == 6)
92+
inferenceParameters.AddInputTensor(guidanceEmbeddings);
93+
inferenceParameters.AddInputTensor(promptEmbeddings.PooledPromptEmbeds);
94+
inferenceParameters.AddInputTensor(timeids);
95+
inferenceParameters.AddOutputBuffer(outputDimension);
96+
97+
var results = await _unet.RunInferenceAsync(inferenceParameters);
98+
using (var result = results.First())
99+
{
100+
var noisePred = result.ToDenseTensor();
101+
102+
// Perform guidance
103+
if (performGuidance)
104+
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
105+
106+
// Scheduler Step
107+
latents = scheduler.Step(noisePred, timestep, latents).Result;
108+
}
109+
}
110+
111+
ReportProgress(progressCallback, step, timesteps.Count, latents);
112+
_logger?.LogEnd(LogLevel.Debug, $"Step {step}/{timesteps.Count}", stepTime);
113+
}
114+
115+
// Unload if required
116+
if (_memoryMode == MemoryModeType.Minimum)
117+
await _unet.UnloadAsync();
118+
119+
// Decode Latents
120+
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
121+
}
122+
}
123+
124+
32125
/// <summary>
33126
/// Gets the scheduler.
34127
/// </summary>
@@ -42,5 +135,26 @@ protected override IScheduler GetScheduler(SchedulerOptions options)
42135
_ => default
43136
};
44137
}
138+
139+
140+
/// <summary>
141+
/// Gets the guidance scale embedding.
142+
/// </summary>
143+
/// <param name="options">The options.</param>
144+
/// <param name="embeddingDim">The embedding dim.</param>
145+
/// <returns></returns>
146+
protected DenseTensor<float> GetGuidanceScaleEmbedding(float guidance, int embeddingDim = 256)
147+
{
148+
var scale = (guidance - 1f) * 1000.0f;
149+
var halfDim = embeddingDim / 2;
150+
float log = MathF.Log(10000.0f) / (halfDim - 1);
151+
var emb = Enumerable.Range(0, halfDim)
152+
.Select(x => scale * MathF.Exp(-log * x))
153+
.ToArray();
154+
var embSin = emb.Select(MathF.Sin);
155+
var embCos = emb.Select(MathF.Cos);
156+
var guidanceEmbedding = embSin.Concat(embCos).ToArray();
157+
return new DenseTensor<float>(guidanceEmbedding, new[] { 1, embeddingDim });
158+
}
45159
}
46160
}

0 commit comments

Comments
 (0)