Skip to content

Commit 6175bfa

Browse files
authored
Merge pull request aeron-io#569 from ZackPierce/rust_const_enums
Handle IR changes to const enums for Rust; re-enable regression tests
2 parents e4a26c7 + 454a117 commit 6175bfa

File tree

5 files changed

+127
-23
lines changed

5 files changed

+127
-23
lines changed

rust/car_example/src/main.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ impl std::convert::From<CodecErr> for IoError {
4646
fn decode_car_and_assert_expected_content(buffer: &[u8]) -> CodecResult<()> {
4747
let (h, dec_fields) = start_decoding_car(&buffer).header()?;
4848
assert_eq!(49u16, h.block_length);
49+
assert_eq!(h.block_length as usize, ::std::mem::size_of::<CarFields>());
4950
assert_eq!(1u16, h.template_id);
5051
assert_eq!(1u16, h.schema_id);
5152
assert_eq!(0u16, h.version);
@@ -67,6 +68,11 @@ fn decode_car_and_assert_expected_content(buffer: &[u8]) -> CodecResult<()> {
6768
assert!(fields.extras.get_cruise_control());
6869
assert!(fields.extras.get_sports_pack());
6970
assert!(!fields.extras.get_sun_roof());
71+
assert_eq!(2000, fields.engine.capacity);
72+
assert_eq!(4, fields.engine.num_cylinders);
73+
assert_eq!(BoostType::NITROUS, fields.engine.booster.boost_type);
74+
assert_eq!(200, fields.engine.booster.horse_power);
75+
println!("Static-length fields all match the expected values");
7076

7177
let dec_perf_figures_header = match dec_fuel_figures_header.fuel_figures_individually()? {
7278
Either::Left(mut dec_ff_members) => {

sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/NamedToken.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,16 @@ public Token typeToken()
4343
return typeToken;
4444
}
4545

46-
public static List<NamedToken> gatherNamedFieldTokens(final List<Token> fields)
46+
public static List<NamedToken> gatherNamedNonConstantFieldTokens(final List<Token> fields)
4747
{
4848
final List<NamedToken> namedTokens = new ArrayList<>();
49-
forEachField(fields, (f, t) -> namedTokens.add(new NamedToken(f.name(), t)));
49+
forEachField(fields, (f, t) ->
50+
{
51+
if (!f.isConstantEncoding())
52+
{
53+
namedTokens.add(new NamedToken(f.name(), t));
54+
}
55+
});
5056

5157
return namedTokens;
5258
}

sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustGenerator.java

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,7 @@ private static Optional<FieldsRepresentationSummary> generateFieldsRepresentatio
133133
final MessageComponents components,
134134
final OutputManager outputManager) throws IOException
135135
{
136-
final List<NamedToken> namedFieldTokens = NamedToken.gatherNamedFieldTokens(components.fields);
137-
if (namedFieldTokens.isEmpty())
138-
{
139-
return Optional.empty();
140-
}
136+
final List<NamedToken> namedFieldTokens = NamedToken.gatherNamedNonConstantFieldTokens(components.fields);
141137

142138
final String representationStruct = messageTypeName + "Fields";
143139
try (Writer writer = outputManager.createOutput(messageTypeName + " Fixed-size Fields"))
@@ -149,12 +145,36 @@ private static Optional<FieldsRepresentationSummary> generateFieldsRepresentatio
149145
generateConstantAccessorImpl(writer, representationStruct, components.fields);
150146
}
151147

152-
final int numBytes = components.fields.stream()
153-
.filter((t) -> !t.isConstantEncoding())
154-
.filter((t) -> t.signal() == ENCODING || t.signal() == BEGIN_ENUM || t.signal() == BEGIN_SET)
155-
.mapToInt(Token::encodedLength)
156-
.sum();
157-
148+
// Compute the total static size in bytes of the fields representation
149+
int numBytes = 0;
150+
for (int i = 0, size = components.fields.size(); i < size;)
151+
{
152+
final Token fieldToken = components.fields.get(i);
153+
if (fieldToken.signal() == Signal.BEGIN_FIELD)
154+
{
155+
final int fieldEnd = i + fieldToken.componentTokenCount();
156+
if (!fieldToken.isConstantEncoding())
157+
{
158+
for (int j = i; j < fieldEnd; j++)
159+
{
160+
final Token t = components.fields.get(j);
161+
if (t.isConstantEncoding())
162+
{
163+
continue;
164+
}
165+
if (t.signal() == ENCODING || t.signal() == BEGIN_ENUM || t.signal() == BEGIN_SET)
166+
{
167+
numBytes += t.encodedLength();
168+
}
169+
}
170+
}
171+
i += fieldToken.componentTokenCount();
172+
}
173+
else
174+
{
175+
throw new IllegalStateException("field tokens must include bounding BEGIN_FIELD and END_FIELD tokens");
176+
}
177+
}
158178
return Optional.of(new FieldsRepresentationSummary(representationStruct, numBytes));
159179
}
160180

@@ -847,7 +867,7 @@ static class GroupTreeNode
847867
this.blockLengthType = blockLengthType;
848868
this.blockLength = blockLength;
849869
this.rawFields = fields;
850-
this.simpleNamedFields = NamedToken.gatherNamedFieldTokens(fields);
870+
this.simpleNamedFields = NamedToken.gatherNamedNonConstantFieldTokens(fields);
851871
this.varData = varData;
852872

853873
parent.ifPresent((p) -> p.addChild(this));
@@ -927,7 +947,7 @@ String generateVarDataEncoder(
927947
indent(writer, 3).append("return Err(CodecErr::SliceIsLongerThanAllowedBySchema)\n");
928948
indent(writer, 2).append("}\n");
929949
indent(writer, 2).append("// Write data length\n");
930-
indent(writer, 2, "%s.write_type::<%s>(&(l as %s), %s); // group length\n",
950+
indent(writer, 2, "%s.write_type::<%s>(&(l as %s), %s)?; // group length\n",
931951
toScratchChain(groupDepth), rustTypeName(this.lengthType), rustTypeName(this.lengthType),
932952
this.lengthType.size());
933953
indent(writer, 2).append(format("%s.write_slice_without_count::<%s>(s, %s)?;\n",
@@ -1548,22 +1568,27 @@ private static void generateConstantAccessorImpl(
15481568

15491569
case BEGIN_ENUM:
15501570
final String enumType = formatTypeName(signalToken.applicableTypeName());
1551-
String enumValue = null;
1571+
final String rawConstValueName = fieldToken.encoding().constValue().toString();
1572+
final int indexOfDot = rawConstValueName.indexOf('.');
1573+
final String constValueName = -1 == indexOfDot ?
1574+
rawConstValueName : rawConstValueName.substring(indexOfDot + 1);
1575+
boolean foundMatchingValueName = false;
15521576
for (int j = i; j < unfilteredFields.size(); j++)
15531577
{
15541578
final Token searchAhead = unfilteredFields.get(j);
1555-
if (searchAhead.signal() == VALID_VALUE)
1579+
if (searchAhead.signal() == VALID_VALUE && searchAhead.name().equals(constValueName))
15561580
{
1557-
enumValue = searchAhead.name();
1581+
foundMatchingValueName = true;
15581582
break;
15591583
}
15601584
}
1561-
if (enumValue == null)
1585+
if (!foundMatchingValueName)
15621586
{
1563-
throw new IllegalStateException("Found a constant enum field with incomplete token content");
1587+
throw new IllegalStateException(format("Found a constant enum field that requested value %s, " +
1588+
"which is not an available enum option.", rawConstValueName));
15641589
}
15651590
constantRustTypeName = enumType;
1566-
constantRustExpression = enumType + "::" + enumValue;
1591+
constantRustExpression = enumType + "::" + constValueName;
15671592
break;
15681593

15691594
case BEGIN_SET:

sbe-tool/src/test/java/uk/co/real_logic/sbe/generation/rust/RustGeneratorTest.java

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,11 @@ public void fullGenerateBroadUseCase() throws IOException, InterruptedException
133133
{
134134
final String generatedRust = fullGenerateForResource(outputManager, "example-schema");
135135
assertContainsSharedImports(generatedRust);
136+
assertContains(generatedRust,
137+
"pub fn car_fields(mut self) -> CodecResult<(&'d CarFields, CarFuelFiguresHeaderDecoder<'d>)> {\n" +
138+
" let v = self.scratch.read_type::<CarFields>(49)?;\n" +
139+
" Ok((v, CarFuelFiguresHeaderDecoder::wrap(self.scratch)))\n" +
140+
" }");
136141
final String expectedBooleanTypeDeclaration =
137142
"#[derive(Clone,Copy,Debug,PartialEq,Eq,PartialOrd,Ord,Hash)]\n" +
138143
"#[repr(u8)]\n" +
@@ -236,7 +241,6 @@ private void assertSchemaInterpretableAsRust(final String localResourceSchema)
236241
assertRustBuildable(rust, Optional.of(localResourceSchema));
237242
}
238243

239-
@Ignore
240244
@Test
241245
public void checkValidRustFromAllExampleSchema() throws IOException, InterruptedException
242246
{
@@ -267,7 +271,36 @@ public void checkValidRustFromAllExampleSchema() throws IOException, Interrupted
267271
}
268272
}
269273

270-
@Ignore
274+
@Test
275+
public void constantEnumFields() throws IOException, InterruptedException
276+
{
277+
final String rust = fullGenerateForResource(outputManager, "constant-enum-fields");
278+
assertContainsSharedImports(rust);
279+
final String expectedCharTypeDeclaration =
280+
"#[derive(Clone,Copy,Debug,PartialEq,Eq,PartialOrd,Ord,Hash)]\n" +
281+
"#[repr(i8)]\n" +
282+
"pub enum Model {\n" +
283+
" A = 65i8,\n" +
284+
" B = 66i8,\n" +
285+
" C = 67i8,\n" +
286+
"}\n";
287+
assertContains(rust, expectedCharTypeDeclaration);
288+
assertContains(rust, "pub struct ConstantEnumsFields {\n}");
289+
assertContains(rust, "impl ConstantEnumsFields {");
290+
assertContains(rust, " pub fn c() -> Model {\n" +
291+
" Model::C\n }");
292+
assertContains(rust, "impl ConstantEnumsFMember {");
293+
assertContains(rust, " pub fn k() -> Model {\n" +
294+
" Model::C\n }");
295+
assertContains(rust,
296+
"pub fn constant_enums_fields(mut self) -> " +
297+
"CodecResult<(&'d ConstantEnumsFields, ConstantEnumsFHeaderDecoder<'d>)> {\n" +
298+
" let v = self.scratch.read_type::<ConstantEnumsFields>(0)?;\n" +
299+
" Ok((v, ConstantEnumsFHeaderDecoder::wrap(self.scratch)))\n" +
300+
" }");
301+
assertRustBuildable(rust, Optional.of("constant-enum-fields"));
302+
}
303+
271304
@Test
272305
public void constantFieldsCase() throws IOException, InterruptedException
273306
{
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
2+
<sbe:messageSchema xmlns:sbe="http://fixprotocol.io/2016/sbe"
3+
package="baseline"
4+
id="1"
5+
version="0"
6+
semanticVersion="5.2"
7+
description="Enum Constants as top-level and group field"
8+
byteOrder="littleEndian">
9+
<types>
10+
<composite name="messageHeader" description="Message identifiers and length of message root">
11+
<type name="blockLength" primitiveType="uint16"/>
12+
<type name="templateId" primitiveType="uint16"/>
13+
<type name="schemaId" primitiveType="uint16"/>
14+
<type name="version" primitiveType="uint16"/>
15+
</composite>
16+
<composite name="groupSizeEncoding" description="Repeating group dimensions">
17+
<type name="blockLength" primitiveType="uint16"/>
18+
<type name="numInGroup" primitiveType="uint16"/>
19+
</composite>
20+
</types>
21+
<types>
22+
<enum name="Model" encodingType="char">
23+
<validValue name="A">A</validValue>
24+
<validValue name="B">B</validValue>
25+
<validValue name="C">C</validValue>
26+
</enum>
27+
</types>
28+
<sbe:message name="ConstantEnums" id="1" description="">
29+
<field name="c" id="4" type="Model" presence="constant" valueRef="Model.C"/>
30+
<group name="f" id="7" dimensionType="groupSizeEncoding">
31+
<field name="k" id="12" type="Model" presence="constant" valueRef="Model.C"/>
32+
</group>
33+
</sbe:message>
34+
</sbe:messageSchema>

0 commit comments

Comments
 (0)