@@ -106,6 +106,8 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
106106 // Denoised result
107107 DenseTensor < float > denoised = null ;
108108
109+ var unetSession = _onnxModelService . GetOnnxSession ( OnnxModelType . Unet ) ;
110+
109111 // Loop though the timesteps
110112 var step = 0 ;
111113 foreach ( var timestep in timesteps )
@@ -118,7 +120,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
118120 var inputTensor = scheduler . ScaleInput ( latents , timestep ) ;
119121
120122 // Create Input Parameters
121- var inputParameters = CreateUnetInputParams ( modelOptions , inputTensor , promptEmbeddings , guidanceEmbeddings , timestep ) ;
123+ var inputParameters = CreateUnetInputParams ( modelOptions , inputTensor , promptEmbeddings , guidanceEmbeddings , timestep , unetSession ) ;
122124
123125 // Run Inference
124126 using ( var inferResult = await _onnxModelService . RunInferenceAsync ( modelOptions , OnnxModelType . Unet , inputParameters ) )
@@ -190,16 +192,25 @@ protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions mod
190192 /// <param name="promptEmbeddings">The prompt embeddings.</param>
191193 /// <param name="timestep">The timestep.</param>
192194 /// <returns></returns>
193- protected virtual IReadOnlyList < NamedOnnxValue > CreateUnetInputParams ( IModelOptions model , DenseTensor < float > inputTensor , DenseTensor < float > promptEmbeddings , DenseTensor < float > guidanceEmbeddings , int timestep )
195+ protected virtual IReadOnlyList < NamedOnnxValue > CreateUnetInputParams ( IModelOptions model , DenseTensor < float > inputTensor , DenseTensor < float > promptEmbeddings , DenseTensor < float > guidanceEmbeddings , int timestep , InferenceSession unetSession )
194196 {
195197 var inputNames = _onnxModelService . GetInputNames ( model , OnnxModelType . Unet ) ;
196198 return CreateInputParameters (
197199 NamedOnnxValue . CreateFromTensor ( inputNames [ 0 ] , inputTensor ) ,
198- NamedOnnxValue . CreateFromTensor ( inputNames [ 1 ] , new DenseTensor < long > ( new long [ ] { timestep } , new int [ ] { 1 } ) ) ,
200+ GetTimestep ( model , timestep , inputNames , unetSession ) ,
199201 NamedOnnxValue . CreateFromTensor ( inputNames [ 2 ] , promptEmbeddings ) ,
200202 NamedOnnxValue . CreateFromTensor ( inputNames [ 3 ] , guidanceEmbeddings ) ) ;
201203 }
202204
205+ protected virtual NamedOnnxValue GetTimestep ( IModelOptions model , int timestep , IReadOnlyList < string > inputNames , InferenceSession unetSession )
206+ {
207+ TensorElementType elementDataType = unetSession . InputMetadata [ "timestep" ] . ElementDataType ;
208+ if ( elementDataType == TensorElementType . Float )
209+ return NamedOnnxValue . CreateFromTensor ( inputNames [ 1 ] , new DenseTensor < float > ( new float [ ] { timestep } , new int [ ] { 1 } ) ) ;
210+ if ( elementDataType == TensorElementType . Int64 )
211+ return NamedOnnxValue . CreateFromTensor ( inputNames [ 1 ] , new DenseTensor < long > ( new long [ ] { timestep } , new int [ ] { 1 } ) ) ;
212+ throw new Exception ( $ "Unable to map timestep to tensor value `{ elementDataType } `") ;
213+ }
203214
204215 /// <summary>
205216 /// Gets the scheduler.
0 commit comments