1
1
using Microsoft . Extensions . Logging ;
2
+ using Microsoft . ML . OnnxRuntime . Tensors ;
3
+ using OnnxStack . Core ;
2
4
using OnnxStack . Core . Model ;
3
5
using OnnxStack . StableDiffusion . Common ;
4
6
using OnnxStack . StableDiffusion . Config ;
5
7
using OnnxStack . StableDiffusion . Diffusers . StableDiffusionXL ;
6
8
using OnnxStack . StableDiffusion . Enums ;
7
9
using OnnxStack . StableDiffusion . Models ;
8
10
using OnnxStack . StableDiffusion . Schedulers . LatentConsistency ;
11
+ using System . Diagnostics ;
12
+ using System . Linq ;
13
+ using System . Threading . Tasks ;
14
+ using System . Threading ;
15
+ using System ;
9
16
10
17
namespace OnnxStack . StableDiffusion . Diffusers . LatentConsistencyXL
11
18
{
@@ -29,6 +36,92 @@ protected LatentConsistencyXLDiffuser(UNetConditionModel unet, AutoEncoderModel
29
36
public override DiffuserPipelineType PipelineType => DiffuserPipelineType . LatentConsistencyXL ;
30
37
31
38
39
+ /// <summary>
40
+ /// Runs the scheduler steps.
41
+ /// </summary>
42
+ /// <param name="modelOptions">The model options.</param>
43
+ /// <param name="promptOptions">The prompt options.</param>
44
+ /// <param name="schedulerOptions">The scheduler options.</param>
45
+ /// <param name="promptEmbeddings">The prompt embeddings.</param>
46
+ /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
47
+ /// <param name="progressCallback">The progress callback.</param>
48
+ /// <param name="cancellationToken">The cancellation token.</param>
49
+ /// <returns></returns>
50
+ public override async Task < DenseTensor < float > > DiffuseAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
51
+ {
52
+ // Get Scheduler
53
+ using ( var scheduler = GetScheduler ( schedulerOptions ) )
54
+ {
55
+ // Get timesteps
56
+ var timesteps = GetTimesteps ( schedulerOptions , scheduler ) ;
57
+
58
+ // Create latent sample
59
+ var latents = await PrepareLatentsAsync ( promptOptions , schedulerOptions , scheduler , timesteps ) ;
60
+
61
+ // Get Model metadata
62
+ var metadata = await _unet . GetMetadataAsync ( ) ;
63
+
64
+ // Get Time ids
65
+ var addTimeIds = GetAddTimeIds ( schedulerOptions ) ;
66
+
67
+ // Get Guidance Scale Embedding
68
+ var guidanceEmbeddings = GetGuidanceScaleEmbedding ( schedulerOptions . GuidanceScale ) ;
69
+
70
+ // Loop though the timesteps
71
+ var step = 0 ;
72
+ foreach ( var timestep in timesteps )
73
+ {
74
+ step ++ ;
75
+ var stepTime = Stopwatch . GetTimestamp ( ) ;
76
+ cancellationToken . ThrowIfCancellationRequested ( ) ;
77
+
78
+ // Create input tensor.
79
+ var inputLatent = performGuidance ? latents . Repeat ( 2 ) : latents ;
80
+ var inputTensor = scheduler . ScaleInput ( inputLatent , timestep ) ;
81
+ var timestepTensor = CreateTimestepTensor ( timestep ) ;
82
+ var timeids = performGuidance ? addTimeIds . Repeat ( 2 ) : addTimeIds ;
83
+
84
+ var outputChannels = performGuidance ? 2 : 1 ;
85
+ var outputDimension = schedulerOptions . GetScaledDimension ( outputChannels ) ;
86
+ using ( var inferenceParameters = new OnnxInferenceParameters ( metadata ) )
87
+ {
88
+ inferenceParameters . AddInputTensor ( inputTensor ) ;
89
+ inferenceParameters . AddInputTensor ( timestepTensor ) ;
90
+ inferenceParameters . AddInputTensor ( promptEmbeddings . PromptEmbeds ) ;
91
+ if ( inferenceParameters . InputCount == 6 )
92
+ inferenceParameters . AddInputTensor ( guidanceEmbeddings ) ;
93
+ inferenceParameters . AddInputTensor ( promptEmbeddings . PooledPromptEmbeds ) ;
94
+ inferenceParameters . AddInputTensor ( timeids ) ;
95
+ inferenceParameters . AddOutputBuffer ( outputDimension ) ;
96
+
97
+ var results = await _unet . RunInferenceAsync ( inferenceParameters ) ;
98
+ using ( var result = results . First ( ) )
99
+ {
100
+ var noisePred = result . ToDenseTensor ( ) ;
101
+
102
+ // Perform guidance
103
+ if ( performGuidance )
104
+ noisePred = PerformGuidance ( noisePred , schedulerOptions . GuidanceScale ) ;
105
+
106
+ // Scheduler Step
107
+ latents = scheduler . Step ( noisePred , timestep , latents ) . Result ;
108
+ }
109
+ }
110
+
111
+ ReportProgress ( progressCallback , step , timesteps . Count , latents ) ;
112
+ _logger ? . LogEnd ( LogLevel . Debug , $ "Step { step } /{ timesteps . Count } ", stepTime ) ;
113
+ }
114
+
115
+ // Unload if required
116
+ if ( _memoryMode == MemoryModeType . Minimum )
117
+ await _unet . UnloadAsync ( ) ;
118
+
119
+ // Decode Latents
120
+ return await DecodeLatentsAsync ( promptOptions , schedulerOptions , latents ) ;
121
+ }
122
+ }
123
+
124
+
32
125
/// <summary>
33
126
/// Gets the scheduler.
34
127
/// </summary>
@@ -42,5 +135,26 @@ protected override IScheduler GetScheduler(SchedulerOptions options)
42
135
_ => default
43
136
} ;
44
137
}
138
+
139
+
140
+ /// <summary>
141
+ /// Gets the guidance scale embedding.
142
+ /// </summary>
143
+ /// <param name="options">The options.</param>
144
+ /// <param name="embeddingDim">The embedding dim.</param>
145
+ /// <returns></returns>
146
+ protected DenseTensor < float > GetGuidanceScaleEmbedding ( float guidance , int embeddingDim = 256 )
147
+ {
148
+ var scale = ( guidance - 1f ) * 1000.0f ;
149
+ var halfDim = embeddingDim / 2 ;
150
+ float log = MathF . Log ( 10000.0f ) / ( halfDim - 1 ) ;
151
+ var emb = Enumerable . Range ( 0 , halfDim )
152
+ . Select ( x => scale * MathF . Exp ( - log * x ) )
153
+ . ToArray ( ) ;
154
+ var embSin = emb . Select ( MathF . Sin ) ;
155
+ var embCos = emb . Select ( MathF . Cos ) ;
156
+ var guidanceEmbedding = embSin . Concat ( embCos ) . ToArray ( ) ;
157
+ return new DenseTensor < float > ( guidanceEmbedding , new [ ] { 1 , embeddingDim } ) ;
158
+ }
45
159
}
46
160
}
0 commit comments