Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 46b8ccb

Browse files
committed
Modify timestep tensor type to support LCM fp16
1 parent e56a7b6 commit 46b8ccb

File tree

3 files changed

+40
-6
lines changed

3 files changed

+40
-6
lines changed

OnnxStack.Core/Services/IOnnxModelService.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,5 +108,12 @@ public interface IOnnxModelService : IDisposable
108108
/// <param name="modelType">Type of model.</param>
109109
/// <returns></returns>
110110
IReadOnlyList<string> GetOutputNames(IOnnxModel model, OnnxModelType modelType);
111+
112+
/// <summary>
113+
/// Gets the InferenceSession
114+
/// </summary>
115+
/// <param name="type">Type of the InferenceSession to get</param>
116+
/// <returns></returns>
117+
InferenceSession GetOnnxSession(OnnxModelType type);
111118
}
112119
}

OnnxStack.Core/Services/OnnxModelService.cs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,15 @@ public IReadOnlyList<string> GetOutputNames(IOnnxModel model, OnnxModelType mode
174174
return OutputNamesInternal(model, modelType);
175175
}
176176

177+
/// <summary>
178+
/// Gets the InferenceSession
179+
/// </summary>
180+
/// <param name="type">Type of the InferenceSession to get</param>
181+
/// <returns></returns>
182+
public InferenceSession GetOnnxSession(OnnxModelType type)
183+
{
184+
return _onnxModelSets.First().Value.GetSession(type);
185+
}
177186

178187
/// <summary>
179188
/// Runs inference on the specified model.
@@ -183,9 +192,16 @@ public IReadOnlyList<string> GetOutputNames(IOnnxModel model, OnnxModelType mode
183192
/// <returns></returns>
184193
private IDisposableReadOnlyCollection<DisposableNamedOnnxValue> RunInternal(IOnnxModel model, OnnxModelType modelType, IReadOnlyCollection<NamedOnnxValue> inputs)
185194
{
186-
return GetModelSet(model)
187-
.GetSession(modelType)
188-
.Run(inputs);
195+
try
196+
{
197+
return GetModelSet(model)
198+
.GetSession(modelType)
199+
.Run(inputs);
200+
}
201+
catch (Exception ex)
202+
{
203+
return default;
204+
}
189205
}
190206

191207

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
106106
// Denoised result
107107
DenseTensor<float> denoised = null;
108108

109+
var unetSession = _onnxModelService.GetOnnxSession(OnnxModelType.Unet);
110+
109111
// Loop though the timesteps
110112
var step = 0;
111113
foreach (var timestep in timesteps)
@@ -118,7 +120,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
118120
var inputTensor = scheduler.ScaleInput(latents, timestep);
119121

120122
// Create Input Parameters
121-
var inputParameters = CreateUnetInputParams(modelOptions, inputTensor, promptEmbeddings, guidanceEmbeddings, timestep);
123+
var inputParameters = CreateUnetInputParams(modelOptions, inputTensor, promptEmbeddings, guidanceEmbeddings, timestep, unetSession);
122124

123125
// Run Inference
124126
using (var inferResult = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputParameters))
@@ -190,16 +192,25 @@ protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions mod
190192
/// <param name="promptEmbeddings">The prompt embeddings.</param>
191193
/// <param name="timestep">The timestep.</param>
192194
/// <returns></returns>
193-
protected virtual IReadOnlyList<NamedOnnxValue> CreateUnetInputParams(IModelOptions model, DenseTensor<float> inputTensor, DenseTensor<float> promptEmbeddings, DenseTensor<float> guidanceEmbeddings, int timestep)
195+
protected virtual IReadOnlyList<NamedOnnxValue> CreateUnetInputParams(IModelOptions model, DenseTensor<float> inputTensor, DenseTensor<float> promptEmbeddings, DenseTensor<float> guidanceEmbeddings, int timestep, InferenceSession unetSession)
194196
{
195197
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.Unet);
196198
return CreateInputParameters(
197199
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
198-
NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 })),
200+
GetTimestep(model, timestep, inputNames, unetSession),
199201
NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings),
200202
NamedOnnxValue.CreateFromTensor(inputNames[3], guidanceEmbeddings));
201203
}
202204

205+
protected virtual NamedOnnxValue GetTimestep(IModelOptions model, int timestep, IReadOnlyList<string> inputNames, InferenceSession unetSession)
206+
{
207+
TensorElementType elementDataType = unetSession.InputMetadata["timestep"].ElementDataType;
208+
if (elementDataType == TensorElementType.Float)
209+
return NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<float>(new float[] { timestep }, new int[] { 1 }));
210+
if (elementDataType == TensorElementType.Int64)
211+
return NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 }));
212+
throw new Exception($"Unable to map timestep to tensor value `{elementDataType}`");
213+
}
203214

204215
/// <summary>
205216
/// Gets the scheduler.

0 commit comments

Comments
 (0)