Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 88 additions & 21 deletions lib/SILOptimizer/Mandatory/TFLowerGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -807,11 +807,6 @@ static TF_Tensor *convertValuesToTensor(ArrayRef<APInt> elts,
auto *ptr = (char *)TF_TensorData(tensor);

for (auto elt : elts) {
if (dtype == TF_BOOL) {
// Swift/LLVM uses a 1-bit APInt to represent a bool, but TF_BOOL is 1
// byte or more.
elt = elt.zext(dtypeSize * 8);
}
assert(elt.getBitWidth() == dtypeSize * 8);
memcpy(ptr, elt.getRawData(), dtypeSize);
ptr += dtypeSize;
Expand Down Expand Up @@ -1881,9 +1876,58 @@ GLStatus TFGraphLowering::visitGraphOperationInst(GraphOperationInst *inst) {
break;
}
case SymbolicValue::Array: {
// FIXME: Handle array attributes.
CanType elementType;
auto rawElements = attrValue.getArrayValue(elementType);
SmallVector<SymbolicValue, 4> elements;
elements.reserve(rawElements.size());
for (auto elt : rawElements)
elements.push_back(elt.lookThroughSingleElementAggregates());

auto elementTypeString = elementType->getString();

// TODO: TF_SetAttrTypeList

if (elementTypeString == "String") {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there any enum that we can match on, rather than doing (sub)string comparison?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, not really (that I know of). This really comes down to what types we want to allow in #tfop attribute lists. We could tighten this up to check that the types are coming from Swift module. I'll dig around to see what ASTContext has to work with here.

SmallVector<const void*, 4> pointers;
SmallVector<size_t, 4> sizes;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to call reserve() for pointers and sizes as well? same for values below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, done.

for (auto elt : elements) {
auto bytes = elt.getStringValue();
pointers.push_back(bytes.data());
sizes.push_back(bytes.size());
}
TF_SetAttrStringList(op, name.c_str(), pointers.data(), sizes.data(),
elements.size());
break;
}
if (StringRef(elementTypeString).startswith("Int")) {
SmallVector<int64_t, 4> values;
for (auto elt : elements)
values.push_back(elt.getIntegerValue().getLimitedValue());
TF_SetAttrIntList(op, name.c_str(), values.data(), values.size());
break;
}
if (elementTypeString == "Float" || elementTypeString == "Double") {
SmallVector<float, 4> values;
for (auto elt : elements) {
auto value = elt.getFloatValue();
bool losesInfo = false;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a comment on why losing precision here is ok?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well it really isn't, but this is tensorflow.... :-)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(but sure, I'll add a comment)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Some of this could be refactored into helper functions. I guess we can do it later.

value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
&losesInfo);
values.push_back(value.convertToFloat());
}
TF_SetAttrFloatList(op, name.c_str(), values.data(), values.size());
break;
}
if (elementTypeString == "Bool") {
SmallVector<unsigned char, 4> values;
for (auto elt : elements)
values.push_back(elt.getIntegerValue().getLimitedValue() != 0);
TF_SetAttrBoolList(op, name.c_str(), values.data(), values.size());
break;
}

internalError(getUserSourceLocation(inst->getDebugLocation()),
"FIXME: Handle array attributes");
"unknown array attribute");
return GLStatus::Error;
}
}
Expand All @@ -1900,42 +1944,67 @@ GLStatus TFGraphLowering::visitGraphOperationInst(GraphOperationInst *inst) {
inst->dump();
llvm_unreachable("dtype attr must have been processed!");
}
auto dtype = (TF_DataType)dtypeAttr;
auto dtypeSize = TF_DataTypeSize(dtype);

// Tensor can support two cases: an array case (not yet implemented), and
// a scalar case.
SmallVector<APInt, 4> elements;
SmallVector<int64_t, 4> shape;

// The scalar case is very simple, the shape of a scalar is 0d, and the
// data type comes from an attr that should already be processed.
auto addScalar = [&](SymbolicValue value) {
// Add a scalar to the elements list, checking that it is the right size
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"checking that it is the right size" -> should we update the comment as we are actually fixing the sizes in the impl?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// for our dtype.
auto addScalar = [&](SymbolicValue value) -> bool {
value = value.lookThroughSingleElementAggregates();

if (value.getKind() == SymbolicValue::Integer) {
auto intVal = value.getIntegerValue();
if (dtype == TF_BOOL) {
// Swift/LLVM uses a 1-bit APInt to represent a bool, but TF_BOOL is
// 1 byte or more.
intVal = intVal.zext(dtypeSize * 8);
}
elements.push_back(intVal);
} else {
assert(value.getKind() == SymbolicValue::Float);
auto castedIntVal = value.getFloatValue().bitcastToAPInt();
elements.push_back(castedIntVal);
auto floatValue = value.getFloatValue();
bool losesInfo = false;
// Convert to float if necessary.
if (dtype == TF_FLOAT)
floatValue.convert(APFloat::IEEEsingle(),
APFloat::rmNearestTiesToEven, &losesInfo);
elements.push_back(floatValue.bitcastToAPInt());
}

if (elements.back().getBitWidth() != dtypeSize*8) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: dtypeSize * 8 instead of dtypeSize*8?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

internalError(inst->getLoc(),
"invalid element type for tensor dtype");
return true;
}

return false;
};

// The scalar case is very simple, the shape of a scalar is 0d, and the
// data type comes from an attr that should already be processed.
SmallVector<int64_t, 4> shape;
if (attrValue.getKind() == SymbolicValue::Integer ||
attrValue.getKind() == SymbolicValue::Float) {
addScalar(attrValue);
if (addScalar(attrValue))
return GLStatus::Error;
} else {
// Add all the elements to the elements list.
CanType eltType;
for (auto elt : attrValue.getArrayValue(eltType))
addScalar(elt);
for (auto elt : attrValue.getArrayValue(eltType)) {
if (addScalar(elt))
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: i'd prefer that we pass elements explicitly into addScalar call here, so that the code here looks less "magical". Thoughts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure.

return GLStatus::Error;
}

// Decode the shape attribute which must come next.
auto shapeAttr = inst->getAttribute(nextAttributeNumber++).value;
decodeShapeAttr(shapeAttr, shape);
}
// Set the tensor as the attribute on the graph node.
auto tensor =
convertValuesToTensor(elements, shape, (TF_DataType)dtypeAttr);
auto tensor = convertValuesToTensor(elements, shape, dtype);
TF_SetAttrTensor(op, name.c_str(), tensor, status);
TF_DeleteTensor(tensor);
if (checkStatus(inst->getLoc()))
Expand All @@ -1951,13 +2020,11 @@ GLStatus TFGraphLowering::visitGraphOperationInst(GraphOperationInst *inst) {

case SILTensorOpInfo::OperandClass::ShapeArray:
// TODO: TF_SetAttrShapeList
case SILTensorOpInfo::OperandClass::Array:
// TODO: TF_SetAttrTypeList, TF_SetAttrStringList, TF_SetAttrIntList,
// TF_SetAttrFloatList, TF_SetAttrBoolList
internalError(getUserSourceLocation(inst->getDebugLocation()),
"FIXME: Handle array and shapearray attributes");
return GLStatus::Error;

case SILTensorOpInfo::OperandClass::Array: // Handled as 'normal'
case SILTensorOpInfo::OperandClass::ArrayElement:
llvm_unreachable("This is a legacy class that shouldn't happen");
}
Expand Down
1 change: 0 additions & 1 deletion test/TensorFlow/deabstraction_finished.swift
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ public func test75407624() {


public func testConvolution(x: Tensor<Float>, filter: Tensor<Float>) -> Tensor<Float> {
// expected-error @+1 {{FIXME: Handle array attributes}}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome!!

return x.toAccelerator().convolved2D(withFilter: filter.toAccelerator(),
strides: (1, 2, 3, 4), padding: .same)
}
Expand Down