7 changes: 7 additions & 0 deletions OnnxStack.Core/Services/IOnnxModelService.cs
@@ -108,5 +108,12 @@ public interface IOnnxModelService : IDisposable
/// <param name="modelType">Type of model.</param>
/// <returns></returns>
IReadOnlyList<string> GetOutputNames(IOnnxModel model, OnnxModelType modelType);

/// <summary>
/// Gets the InferenceSession for the specified model type.
/// </summary>
/// <param name="type">Type of the InferenceSession to get.</param>
/// <returns>The InferenceSession registered for the given model type.</returns>
InferenceSession GetOnnxSession(OnnxModelType type);
}
}
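For context, a minimal caller-side sketch of the new interface member. This is illustrative only: `modelService` is a hypothetical, already-resolved `IOnnxModelService` (e.g. via DI), the namespace imports are approximate, and the snippet simply dumps the UNet session's input metadata, which is exactly what the diffuser change below uses the session for.

```csharp
using System;
using Microsoft.ML.OnnxRuntime;

// Hypothetical helper; `IOnnxModelService` and `OnnxModelType` come from OnnxStack.Core.
void DumpUnetInputs(IOnnxModelService modelService)
{
    // Fetch the raw session exposed by the new GetOnnxSession method.
    InferenceSession session = modelService.GetOnnxSession(OnnxModelType.Unet);

    // List every model input with its tensor element type (Float, Int64, ...).
    foreach (var input in session.InputMetadata)
        Console.WriteLine($"{input.Key}: {input.Value.ElementDataType}");
}
```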
22 changes: 19 additions & 3 deletions OnnxStack.Core/Services/OnnxModelService.cs
@@ -174,6 +174,15 @@ public IReadOnlyList<string> GetOutputNames(IOnnxModel model, OnnxModelType mode
return OutputNamesInternal(model, modelType);
}

/// <summary>
/// Gets the InferenceSession for the specified model type.
/// </summary>
/// <param name="type">Type of the InferenceSession to get.</param>
/// <returns>The InferenceSession registered for the given model type.</returns>
public InferenceSession GetOnnxSession(OnnxModelType type)
{
    // Note: resolves against the first loaded model set only; with multiple
    // model sets loaded, sessions belonging to the other sets are not reachable here.
    return _onnxModelSets.First().Value.GetSession(type);
}

/// <summary>
/// Runs inference on the specified model.
@@ -183,9 +192,16 @@ public IReadOnlyList<string> GetOutputNames(IOnnxModel model, OnnxModelType mode
/// <returns></returns>
private IDisposableReadOnlyCollection<DisposableNamedOnnxValue> RunInternal(IOnnxModel model, OnnxModelType modelType, IReadOnlyCollection<NamedOnnxValue> inputs)
{
- return GetModelSet(model)
-     .GetSession(modelType)
-     .Run(inputs);
+ try
+ {
+     return GetModelSet(model)
+         .GetSession(modelType)
+         .Run(inputs);
+ }
+ catch (Exception)
+ {
+     // Swallow inference failures and surface them as a null result;
+     // callers are expected to check for null before using the output.
+     return default;
+ }
}
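One consequence of this change worth flagging: `RunInternal` now swallows exceptions and yields `default` (null) instead of throwing, so every call site must guard against a null result. A minimal hypothetical guard at a call site might look like this (the surrounding method, `model`, and `inputs` are assumed to be in scope):

```csharp
// Hypothetical call site; `results` is null whenever inference failed inside RunInternal.
using var results = RunInternal(model, OnnxModelType.Unet, inputs);
if (results is null)
{
    // The original exception was not rethrown, so only a generic error is available here.
    throw new InvalidOperationException("Unet inference failed.");
}
```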


@@ -106,6 +106,8 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
// Denoised result
DenseTensor<float> denoised = null;

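// Resolve the UNet session once, ahead of the timestep loop,
// and reuse it for every step's input construction.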
var unetSession = _onnxModelService.GetOnnxSession(OnnxModelType.Unet);

// Loop through the timesteps
var step = 0;
foreach (var timestep in timesteps)
@@ -118,7 +120,7 @@
var inputTensor = scheduler.ScaleInput(latents, timestep);

// Create Input Parameters
- var inputParameters = CreateUnetInputParams(modelOptions, inputTensor, promptEmbeddings, guidanceEmbeddings, timestep);
+ var inputParameters = CreateUnetInputParams(modelOptions, inputTensor, promptEmbeddings, guidanceEmbeddings, timestep, unetSession);

// Run Inference
using (var inferResult = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputParameters))
@@ -190,16 +192,25 @@ protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions mod
/// <param name="promptEmbeddings">The prompt embeddings.</param>
/// <param name="timestep">The timestep.</param>
/// <param name="unetSession">The UNet InferenceSession, used to inspect the timestep input's element type.</param>
/// <returns>The ordered list of UNet input values.</returns>
- protected virtual IReadOnlyList<NamedOnnxValue> CreateUnetInputParams(IModelOptions model, DenseTensor<float> inputTensor, DenseTensor<float> promptEmbeddings, DenseTensor<float> guidanceEmbeddings, int timestep)
+ protected virtual IReadOnlyList<NamedOnnxValue> CreateUnetInputParams(IModelOptions model, DenseTensor<float> inputTensor, DenseTensor<float> promptEmbeddings, DenseTensor<float> guidanceEmbeddings, int timestep, InferenceSession unetSession)
{
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.Unet);
return CreateInputParameters(
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
- NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 })),
+ GetTimestep(model, timestep, inputNames, unetSession),
NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings),
NamedOnnxValue.CreateFromTensor(inputNames[3], guidanceEmbeddings));
}

/// <summary>
/// Creates the timestep input, using the element type the loaded UNet expects.
/// </summary>
/// <returns>The timestep as a NamedOnnxValue of the matching element type.</returns>
protected virtual NamedOnnxValue GetTimestep(IModelOptions model, int timestep, IReadOnlyList<string> inputNames, InferenceSession unetSession)
{
    // Assumes the UNet's timestep input is literally named "timestep";
    // some exporters emit it as float32, others as int64.
    TensorElementType elementDataType = unetSession.InputMetadata["timestep"].ElementDataType;
    if (elementDataType == TensorElementType.Float)
        return NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<float>(new float[] { timestep }, new int[] { 1 }));
    if (elementDataType == TensorElementType.Int64)
        return NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 }));
    throw new Exception($"Unable to map timestep to tensor value `{elementDataType}`");
}
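A quick way to see why this branch is needed: UNet models exported through different toolchains disagree on the timestep input's element type (float32 vs int64). A hedged, stand-alone sketch for checking a given model, assuming a local `unet/model.onnx` path and the `timestep` input name this PR relies on:

```csharp
using System;
using Microsoft.ML.OnnxRuntime;

// Path is hypothetical; point it at the UNet of any ONNX Stable Diffusion export.
using var session = new InferenceSession(@"unet/model.onnx");

// Mirrors the check GetTimestep performs: Float for some exporters, Int64 for others.
var timestepMeta = session.InputMetadata["timestep"];
Console.WriteLine($"timestep element type: {timestepMeta.ElementDataType}");
```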

/// <summary>
/// Gets the scheduler.