Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
using Microsoft.ML.Data;
using static Microsoft.ML.Data.TextLoader;

using Range = Microsoft.ML.Data.TextLoader.Range;

namespace Microsoft.ML.AutoML
{
/// <summary>
Expand Down
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,9 @@ protected HostEnvironmentBase(HostEnvironmentBase<TEnv> source, Random rand, boo

// This fork shares some stuff with the master.
Master = source;
GpuDeviceId = Master?.GpuDeviceId;
FallbackToCpu = Master?.FallbackToCpu ?? true;
Seed = Master?.Seed;
Root = source.Root;
ListenerDict = source.ListenerDict;
ProgressTracker = source.ProgressTracker;
Expand Down
26 changes: 26 additions & 0 deletions src/Microsoft.ML.ImageAnalytics/MLImage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Transforms.Image;
using SkiaSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using static Microsoft.ML.Transforms.Image.ImagePixelExtractingEstimator;

namespace Microsoft.ML.Data
{
Expand Down Expand Up @@ -126,6 +128,30 @@ private set
}
}

/// <summary>
/// Gets a copy of the image pixel data as a flat byte array holding three bytes per pixel
/// in Blue, Green, Red order. The alpha channel is not included.
/// </summary>
[System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "<Pending>")]
public byte[] GetBGRPixels
{
    get
    {
        ThrowInvalidOperationExceptionIfDisposed();

        // Three bytes per pixel (blue, green, red); alpha is intentionally dropped.
        var bgr = new byte[Height * Width * 3];

        var source = _image.Pixels;
        int offset = 0;
        for (int pixel = 0; offset < bgr.Length; pixel++)
        {
            var color = source[pixel];
            bgr[offset++] = color.Blue;
            bgr[offset++] = color.Green;
            bgr[offset++] = color.Red;
        }

        return bgr;
    }
}

/// <summary>
/// Gets the image pixel data.
/// </summary>
Expand Down
153 changes: 153 additions & 0 deletions src/Microsoft.ML.TorchSharp/AutoFormerV2/Anchors.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using TorchSharp;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

namespace Microsoft.ML.TorchSharp.AutoFormerV2
{
/// <summary>
/// Produces anchor boxes: a set of predefined bounding boxes of fixed height and width, whose location and size can later be adjusted by the regression head of the model.
/// </summary>
public class Anchors : Module<Tensor, Tensor>
{
    [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
    private readonly int[] pyramidLevels;

    [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
    private readonly int[] strides;

    [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
    private readonly int[] sizes;

    [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
    private readonly double[] ratios;

    [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:private field names not in _camelCase format", Justification = "Need to match TorchSharp.")]
    private readonly double[] scales;

    /// <summary>
    /// Initializes a new instance of the <see cref="Anchors"/> class.
    /// Every argument left as <c>null</c> falls back to a built-in default.
    /// </summary>
    /// <param name="pyramidLevels">Pyramid levels. Defaults to 3, 4, 5, 6, 7.</param>
    /// <param name="strides">Strides between adjacent bboxes, one per level. Defaults to 2^level.</param>
    /// <param name="sizes">Base bbox sizes, one per level. Defaults to 2^(level + 2).</param>
    /// <param name="ratios">Height/width ratios. Defaults to 0.5, 1, 2.</param>
    /// <param name="scales">Scale factors applied to the base size. Defaults to 2^0, 2^(1/3), 2^(2/3).</param>
    public Anchors(int[] pyramidLevels = null, int[] strides = null, int[] sizes = null, double[] ratios = null, double[] scales = null)
        : base(nameof(Anchors))
    {
        this.pyramidLevels = pyramidLevels ?? new int[] { 3, 4, 5, 6, 7 };

        // The per-level defaults are derived from the (possibly defaulted) pyramid levels.
        this.strides = strides ?? this.pyramidLevels.Select(level => (int)Math.Pow(2, level)).ToArray();
        this.sizes = sizes ?? this.pyramidLevels.Select(level => (int)Math.Pow(2, level + 2)).ToArray();

        this.ratios = ratios ?? new double[] { 0.5, 1, 2 };
        this.scales = scales ?? new double[] { Math.Pow(2, 0), Math.Pow(2, 1.0 / 3.0), Math.Pow(2, 2.0 / 3.0) };
    }

    /// <summary>
    /// Generates the anchors of every pyramid level for an image.
    /// </summary>
    /// <param name="image">Image in Tensor format. Only its shape (from the third dimension on) and its device are used.</param>
    /// <returns>All anchors concatenated over the pyramid levels, with a leading dimension of size one, placed on the device of <paramref name="image"/>.</returns>
    [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
    public override Tensor forward(Tensor image)
    {
        using var disposeScope = torch.NewDisposeScope();

        // Spatial extent of the input: everything after the first two dimensions.
        var spatialSize = torch.tensor(image.shape.AsSpan().Slice(2).ToArray());

        // Start from an empty (0 x 4) tensor and append the anchors of each level to it.
        var collected = torch.zeros(new long[] { 0, 4 }, dtype: torch.float32);

        for (int level = 0; level < this.pyramidLevels.Length; ++level)
        {
            double downscale = Math.Pow(2, this.pyramidLevels[level]);

            // Size of the grid at this level: the spatial size divided by 2^level, rounded up.
            var gridShape = ((spatialSize + downscale - 1) / downscale).to_type(torch.int32);

            var levelAnchors = GenerateAnchors(
                baseSize: this.sizes[level],
                ratios: this.ratios,
                scales: this.scales);
            var placedAnchors = Shift(gridShape, this.strides[level], levelAnchors);

            var parts = new List<Tensor>() { collected, placedAnchors };
            collected = torch.cat(parts, dim: 0);
        }

        // Add the leading dimension and follow the input onto its device before leaving the scope.
        return collected.unsqueeze(dim: 0).to(image.device).MoveToOuterDisposeScope();
    }

    /// <summary>
    /// Generates one anchor per (ratio, scale) combination for a given base size.
    /// </summary>
    /// <param name="baseSize">Base size for width and height.</param>
    /// <param name="ratios">Ratios for height/width. Defaults to 0.5, 1, 2 when <c>null</c>.</param>
    /// <param name="scales">Scales to resize base size. Defaults to 2^0, 2^(1/3), 2^(2/3) when <c>null</c>.</param>
    /// <returns>A tensor with one row of four values (x1, y1, x2, y2) per anchor.</returns>
    private static Tensor GenerateAnchors(int baseSize = 16, double[] ratios = null, double[] scales = null)
    {
        using var disposeScope = torch.NewDisposeScope();

        if (ratios == null)
        {
            ratios = new double[] { 0.5, 1, 2 };
        }

        if (scales == null)
        {
            scales = new double[] { Math.Pow(2, 0), Math.Pow(2, 1.0 / 3.0), Math.Pow(2, 2.0 / 3.0) };
        }

        int anchorCount = ratios.Length * scales.Length;

        // Rows are anchors; columns start out as (x_ctr, y_ctr, w, h), all zero.
        var boxes = torch.zeros(new long[] { anchorCount, 4 }, dtype: torch.float32);

        // Width and height both start as the scaled base size.
        var tiledScales = torch.tile(scales, new long[] { 2, ratios.Length }).transpose(1, 0);
        boxes[.., 2..] = baseSize * tiledScales;

        // Area of every anchor before the aspect ratio is applied.
        var areas = torch.mul(boxes[.., 2], boxes[.., 3]);

        // Each ratio is repeated once per scale so it lines up with the anchor rows.
        var ratioPerBox = torch.repeat_interleave(ratios, new long[] { scales.Length });

        // Re-derive width and height from the area and the ratio.
        boxes[.., 2] = torch.sqrt(areas / ratioPerBox);
        boxes[.., 3] = torch.mul(boxes[.., 2], ratioPerBox);

        // Transform from (x_ctr, y_ctr, w, h) to (x1, y1, x2, y2).
        var xColumns = torch.TensorIndex.Tensor(torch.tensor(new long[] { 0, 2 }, dtype: torch.int64));
        var yColumns = torch.TensorIndex.Tensor(torch.tensor(new long[] { 1, 3 }, dtype: torch.int64));
        boxes[.., xColumns] -= torch.tile(boxes[.., 2] * 0.5, new long[] { 2, 1 }).T;
        boxes[.., yColumns] -= torch.tile(boxes[.., 3] * 0.5, new long[] { 2, 1 }).T;

        return boxes.MoveToOuterDisposeScope();
    }

    /// <summary>
    /// Duplicates the anchors and distributes them over a grid of positions, given the border of the grid and the stride between positions.
    /// </summary>
    /// <param name="shape">Border to distribute anchors; element 0 bounds the y positions and element 1 bounds the x positions.</param>
    /// <param name="stride">Stride between adjacent anchors.</param>
    /// <param name="anchors">Anchors to distribute, one row of four values each.</param>
    /// <returns>The shifted anchors, one row of four values per (position, anchor) pair.</returns>
    private static Tensor Shift(Tensor shape, int stride, Tensor anchors)
    {
        using var disposeScope = torch.NewDisposeScope();

        // Positions sit at the centre of each stride-sized cell.
        Tensor centersX = (torch.arange(start: 0, stop: (int)shape[1]) + 0.5) * stride;
        Tensor centersY = (torch.arange(start: 0, stop: (int)shape[0]) + 0.5) * stride;

        long columns = centersX.shape[0];
        long rows = centersY.shape[0];

        // Flatten the grid row by row: x cycles fastest, y stays constant within a row.
        var gridX = torch.repeat_interleave(centersX.reshape(new long[] { columns, 1 }), rows, dim: 1)
            .transpose(0, 1)
            .reshape(-1);
        var gridY = torch.repeat_interleave(centersY, columns);

        // One (x, y, x, y) offset per grid position, matching the (x1, y1, x2, y2) anchor layout.
        var offsets = new List<Tensor> { gridX, gridY, gridX, gridY };
        var shifts = torch.vstack(offsets).transpose(0, 1);

        long anchorCount = anchors.shape[0];
        long positionCount = shifts.shape[0];

        // Broadcast (1, A, 4) against (K, 1, 4) to obtain every anchor at every position.
        var broadcastAnchors = anchors.reshape(new long[] { 1, anchorCount, 4 });
        var broadcastShifts = shifts.reshape(new long[] { 1, positionCount, 4 }).transpose(0, 1);
        var placed = (broadcastAnchors + broadcastShifts).reshape(new long[] { positionCount * anchorCount, 4 });

        return placed.MoveToOuterDisposeScope();
    }
}
}
135 changes: 135 additions & 0 deletions src/Microsoft.ML.TorchSharp/AutoFormerV2/Attention.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

namespace Microsoft.ML.TorchSharp.AutoFormerV2
{
/// <summary>
/// The Attention layer: multi-head attention over the positions of a window, with a learned bias
/// per head that is shared by all position pairs having the same absolute (row, column) offset.
/// </summary>
public class Attention : Module<Tensor, Tensor, Tensor>
{
#pragma warning disable MSML_PrivateFieldName // Need to match TorchSharp model names.
    private readonly int numHeads;
    private readonly double scale;
    private readonly int keyChannels;
    private readonly int nHkD;
    private readonly int d;
    private readonly int dh;
    private readonly double attnRatio;

    private readonly LayerNorm norm;
    private readonly Linear qkv;
    private readonly Linear proj;
    private readonly Parameter attention_biases;
    private readonly TensorIndex attention_bias_idxs;
    private readonly Softmax softmax;
#pragma warning restore MSML_PrivateFieldName


    /// <summary>
    /// Initializes a new instance of the <see cref="Attention"/> class.
    /// </summary>
    /// <param name="inChannels">The input channels.</param>
    /// <param name="keyChannels">The key channels (size of the query and key of each head).</param>
    /// <param name="numHeads">The number of attention heads.</param>
    /// <param name="attnRatio">The ratio of attention; the value size of each head is <paramref name="attnRatio"/> * <paramref name="keyChannels"/>.</param>
    /// <param name="windowResolution">The resolution of window as two entries. Defaults to 14 x 14 when <c>null</c>.</param>
    public Attention(int inChannels, int keyChannels, int numHeads = 8, int attnRatio = 4, List<int> windowResolution = null)
        : base(nameof(Attention))
    {
        windowResolution ??= new List<int>() { 14, 14 };
        this.numHeads = numHeads;
        this.scale = System.Math.Pow(keyChannels, -0.5);
        this.keyChannels = keyChannels;
        this.nHkD = numHeads * keyChannels;
        this.d = attnRatio * keyChannels;
        this.dh = this.d * numHeads;
        this.attnRatio = attnRatio;

        // A single linear layer emits query, key and value of all heads at once.
        int h = this.dh + (this.nHkD * 2);

        this.norm = nn.LayerNorm(new long[] { inChannels });
        this.qkv = nn.Linear(inChannels, h);
        this.proj = nn.Linear(this.dh, inChannels);

        // Enumerate every (row, column) position of the window in row-major order.
        var points = new List<(int Row, int Col)>();
        for (int i = 0; i < windowResolution[0]; i++)
        {
            for (int j = 0; j < windowResolution[1]; j++)
            {
                points.Add((i, j));
            }
        }

        // Give each distinct absolute offset between two positions its own bias slot, numbered in
        // order of first appearance. The n x n lookup table is filled in a managed buffer so the
        // index tensor is created once instead of through n * n element-wise tensor writes.
        int n = points.Count;
        var attentionOffsets = new Dictionary<(int, int), int>();
        var flatIdxs = new long[(long)n * n];
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < n; j++)
            {
                var offset = (Math.Abs(points[i].Row - points[j].Row), Math.Abs(points[i].Col - points[j].Col));
                if (!attentionOffsets.TryGetValue(offset, out int biasIndex))
                {
                    biasIndex = attentionOffsets.Count;
                    attentionOffsets.Add(offset, biasIndex);
                }

                flatIdxs[((long)i * n) + j] = biasIndex;
            }
        }

        this.attention_biases = nn.Parameter(torch.zeros(numHeads, attentionOffsets.Count));

        // The flat temporary is released right away; the reshaped n x n tensor is kept by the index.
        using (var flat = torch.tensor(flatIdxs, dtype: torch.int64))
        {
            this.attention_bias_idxs = TensorIndex.Tensor(flat.reshape(n, n));
        }

        this.softmax = nn.Softmax(dim: -1);
    }

    /// <summary>
    /// Applies layer normalization, multi-head attention with the learned positional bias, and the output projection.
    /// </summary>
    /// <param name="x">The input tensor; its first two dimensions are read as batch size and number of positions.</param>
    /// <param name="mask">Optional mask added to the attention scores before the softmax; its first dimension is read as the number of windows. May be <c>null</c>.</param>
    /// <returns>The projected attention output, moved to the caller's dispose scope.</returns>
    [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
    public override Tensor forward(Tensor x, Tensor mask)
    {
        using (var scope = torch.NewDisposeScope())
        {
            long b = x.shape[0];
            long n = x.shape[1];

            x = this.norm.forward(x);

            // Split the joint projection into per-head query, key and value.
            var qkv = this.qkv.forward(x);
            qkv = qkv.view(b, n, this.numHeads, -1);
            var tmp = qkv.split(new long[] { this.keyChannels, this.keyChannels, this.d }, dim: 3);

            // Move the head dimension in front of the position dimension.
            var q = tmp[0].permute(0, 2, 1, 3);
            var k = tmp[1].permute(0, 2, 1, 3);
            var v = tmp[2].permute(0, 2, 1, 3);

            // Scaled dot-product scores plus the bias looked up for every pair of positions.
            var attn = (torch.matmul(q, k.transpose(-2, -1)) * this.scale) + this.attention_biases[.., this.attention_bias_idxs];
            if (!(mask is null))
            {
                // Group the batch by window so the per-window mask broadcasts over groups and heads.
                long nW = mask.shape[0];
                attn = attn.view(-1, nW, this.numHeads, n, n) + mask.unsqueeze(1).unsqueeze(0);
                attn = attn.view(-1, this.numHeads, n, n);
            }

            attn = this.softmax.forward(attn);

            // Merge the heads back together and project to the input channel count.
            x = torch.matmul(attn, v).transpose(1, 2).reshape(b, n, this.dh);
            x = this.proj.forward(x);

            return x.MoveToOuterDisposeScope();
        }
    }
}
}
Loading