- Notifications
You must be signed in to change notification settings - Fork 10.6k
Implement lowergraph support for string/int/float/bool arrays #18041
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
9af328c f53c26c 815f6b9 84ba040 File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -807,11 +807,6 @@ static TF_Tensor *convertValuesToTensor(ArrayRef<APInt> elts, | |
| auto *ptr = (char *)TF_TensorData(tensor); | ||
| | ||
| for (auto elt : elts) { | ||
| if (dtype == TF_BOOL) { | ||
| // Swift/LLVM uses a 1-bit APInt to represent a bool, but TF_BOOL is 1 | ||
| // byte or more. | ||
| elt = elt.zext(dtypeSize * 8); | ||
| } | ||
| assert(elt.getBitWidth() == dtypeSize * 8); | ||
| memcpy(ptr, elt.getRawData(), dtypeSize); | ||
| ptr += dtypeSize; | ||
| | @@ -1881,9 +1876,58 @@ GLStatus TFGraphLowering::visitGraphOperationInst(GraphOperationInst *inst) { | |
| break; | ||
| } | ||
| case SymbolicValue::Array: { | ||
| // FIXME: Handle array attributes. | ||
| CanType elementType; | ||
| auto rawElements = attrValue.getArrayValue(elementType); | ||
| SmallVector<SymbolicValue, 4> elements; | ||
| elements.reserve(rawElements.size()); | ||
| for (auto elt : rawElements) | ||
| elements.push_back(elt.lookThroughSingleElementAggregates()); | ||
| | ||
| auto elementTypeString = elementType->getString(); | ||
| | ||
| // TODO: TF_SetAttrTypeList | ||
| | ||
| if (elementTypeString == "String") { | ||
| SmallVector<const void*, 4> pointers; | ||
| SmallVector<size_t, 4> sizes; | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we want to call reserve() for Contributor Author There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, done. | ||
| for (auto elt : elements) { | ||
| auto bytes = elt.getStringValue(); | ||
| pointers.push_back(bytes.data()); | ||
| sizes.push_back(bytes.size()); | ||
| } | ||
| TF_SetAttrStringList(op, name.c_str(), pointers.data(), sizes.data(), | ||
| elements.size()); | ||
| break; | ||
| } | ||
| if (StringRef(elementTypeString).startswith("Int")) { | ||
| SmallVector<int64_t, 4> values; | ||
| for (auto elt : elements) | ||
| values.push_back(elt.getIntegerValue().getLimitedValue()); | ||
| TF_SetAttrIntList(op, name.c_str(), values.data(), values.size()); | ||
| break; | ||
| } | ||
| if (elementTypeString == "Float" || elementTypeString == "Double") { | ||
| SmallVector<float, 4> values; | ||
| for (auto elt : elements) { | ||
| auto value = elt.getFloatValue(); | ||
| bool losesInfo = false; | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add a comment on why losing precision here is ok? Contributor Author There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. well it really isn't, but this is tensorflow.... :-) Contributor Author There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (but sure, I'll add a comment) Contributor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Some of this could be refactored into helper functions. I guess we can do it later. | ||
| value.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, | ||
| &losesInfo); | ||
| values.push_back(value.convertToFloat()); | ||
| } | ||
| TF_SetAttrFloatList(op, name.c_str(), values.data(), values.size()); | ||
| break; | ||
| } | ||
| if (elementTypeString == "Bool") { | ||
| SmallVector<unsigned char, 4> values; | ||
| for (auto elt : elements) | ||
| values.push_back(elt.getIntegerValue().getLimitedValue() != 0); | ||
| TF_SetAttrBoolList(op, name.c_str(), values.data(), values.size()); | ||
| break; | ||
| } | ||
| | ||
| internalError(getUserSourceLocation(inst->getDebugLocation()), | ||
| "FIXME: Handle array attributes"); | ||
| "unknown array attribute"); | ||
| return GLStatus::Error; | ||
| } | ||
| } | ||
| | @@ -1900,42 +1944,67 @@ GLStatus TFGraphLowering::visitGraphOperationInst(GraphOperationInst *inst) { | |
| inst->dump(); | ||
| llvm_unreachable("dtype attr must have been processed!"); | ||
| } | ||
| auto dtype = (TF_DataType)dtypeAttr; | ||
| auto dtypeSize = TF_DataTypeSize(dtype); | ||
| | ||
| // Tensor can support two cases: an array case (not yet implemented), and | ||
| // a scalar case. | ||
| SmallVector<APInt, 4> elements; | ||
| SmallVector<int64_t, 4> shape; | ||
| | ||
| // The scalar case is very simple, the shape of a scalar is 0d, and the | ||
| // data type comes from an attr that should already be processed. | ||
| auto addScalar = [&](SymbolicValue value) { | ||
| // Add a scalar to the elements list, checking that it is the right size | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "checking that it is the right size" -> should we update the comment as we are actually fixing the sizes in the impl? Contributor Author There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. | ||
| // for our dtype. | ||
| auto addScalar = [&](SymbolicValue value) -> bool { | ||
| value = value.lookThroughSingleElementAggregates(); | ||
| | ||
| if (value.getKind() == SymbolicValue::Integer) { | ||
| auto intVal = value.getIntegerValue(); | ||
| if (dtype == TF_BOOL) { | ||
| // Swift/LLVM uses a 1-bit APInt to represent a bool, but TF_BOOL is | ||
| // 1 byte or more. | ||
| intVal = intVal.zext(dtypeSize * 8); | ||
| } | ||
| elements.push_back(intVal); | ||
| } else { | ||
| assert(value.getKind() == SymbolicValue::Float); | ||
| auto castedIntVal = value.getFloatValue().bitcastToAPInt(); | ||
| elements.push_back(castedIntVal); | ||
| auto floatValue = value.getFloatValue(); | ||
| bool losesInfo = false; | ||
| // Convert to float if necessary. | ||
| if (dtype == TF_FLOAT) | ||
| floatValue.convert(APFloat::IEEEsingle(), | ||
| APFloat::rmNearestTiesToEven, &losesInfo); | ||
| elements.push_back(floatValue.bitcastToAPInt()); | ||
| } | ||
| | ||
| if (elements.back().getBitWidth() != dtypeSize*8) { | ||
| ||
| internalError(inst->getLoc(), | ||
| "invalid element type for tensor dtype"); | ||
| return true; | ||
| } | ||
| | ||
| return false; | ||
| }; | ||
| | ||
| // The scalar case is very simple, the shape of a scalar is 0d, and the | ||
| // data type comes from an attr that should already be processed. | ||
| SmallVector<int64_t, 4> shape; | ||
| if (attrValue.getKind() == SymbolicValue::Integer || | ||
| attrValue.getKind() == SymbolicValue::Float) { | ||
| addScalar(attrValue); | ||
| if (addScalar(attrValue)) | ||
| return GLStatus::Error; | ||
| } else { | ||
| // Add all the elements to the elements list. | ||
| CanType eltType; | ||
| for (auto elt : attrValue.getArrayValue(eltType)) | ||
| addScalar(elt); | ||
| for (auto elt : attrValue.getArrayValue(eltType)) { | ||
| if (addScalar(elt)) | ||
| ||
| return GLStatus::Error; | ||
| } | ||
| | ||
| // Decode the shape attribute which must come next. | ||
| auto shapeAttr = inst->getAttribute(nextAttributeNumber++).value; | ||
| decodeShapeAttr(shapeAttr, shape); | ||
| } | ||
| // Set the tensor as the attribute on the graph node. | ||
| auto tensor = | ||
| convertValuesToTensor(elements, shape, (TF_DataType)dtypeAttr); | ||
| auto tensor = convertValuesToTensor(elements, shape, dtype); | ||
| TF_SetAttrTensor(op, name.c_str(), tensor, status); | ||
| TF_DeleteTensor(tensor); | ||
| if (checkStatus(inst->getLoc())) | ||
| | @@ -1951,13 +2020,11 @@ GLStatus TFGraphLowering::visitGraphOperationInst(GraphOperationInst *inst) { | |
| | ||
| case SILTensorOpInfo::OperandClass::ShapeArray: | ||
| // TODO: TF_SetAttrShapeList | ||
| case SILTensorOpInfo::OperandClass::Array: | ||
| // TODO: TF_SetAttrTypeList, TF_SetAttrStringList, TF_SetAttrIntList, | ||
| // TF_SetAttrFloatList, TF_SetAttrBoolList | ||
| internalError(getUserSourceLocation(inst->getDebugLocation()), | ||
| "FIXME: Handle array and shapearray attributes"); | ||
| return GLStatus::Error; | ||
| | ||
| case SILTensorOpInfo::OperandClass::Array: // Handled as 'normal' | ||
| case SILTensorOpInfo::OperandClass::ArrayElement: | ||
| llvm_unreachable("This is a legacy class that shouldn't happen"); | ||
| } | ||
| | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -129,7 +129,6 @@ public func test75407624() { | |
| | ||
| | ||
| public func testConvolution(x: Tensor<Float>, filter: Tensor<Float>) -> Tensor<Float> { | ||
| // expected-error @+1 {{FIXME: Handle array attributes}} | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. awesome!! | ||
| return x.toAccelerator().convolved2D(withFilter: filter.toAccelerator(), | ||
| strides: (1, 2, 3, 4), padding: .same) | ||
| } | ||
| | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there any enum that we can match on, rather than doing (sub)string comparison?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately, not really (that I know of). This really comes down to what types we want to allow in #tfop attribute lists. We could tighten this up to check that the types are coming from Swift module. I'll dig around to see what ASTContext has to work with here.