From 85e716b808490dcb124f97df25a6bf1d648ec5c2 Mon Sep 17 00:00:00 2001 From: "Naveed.k" Date: Fri, 9 Jan 2026 08:25:19 +0530 Subject: [PATCH] rust: fix OOB write in generated table scalar setters --- src/idl_gen_rust.cpp | 1025 +---------------- .../possibly_reserved_words_generated.rs | 17 +- 2 files changed, 68 insertions(+), 974 deletions(-) diff --git a/src/idl_gen_rust.cpp b/src/idl_gen_rust.cpp index 72d391e4b13..0cb1ce72d00 100644 --- a/src/idl_gen_rust.cpp +++ b/src/idl_gen_rust.cpp @@ -875,7 +875,6 @@ class RustGenerator : public BaseGenerator { code_ += " Self(b)"; } code_ += " }"; - code_ += "}"; code_ += ""; code_ += "impl ::flatbuffers::Push for {{ENUM_TY}} {"; code_ += " type Output = {{ENUM_TY}};"; @@ -885,7 +884,6 @@ class RustGenerator : public BaseGenerator { " unsafe { ::flatbuffers::emplace_scalar::<{{BASE_TYPE}}>(dst, " "{{INTO_BASE}}) };"; code_ += " }"; - code_ += "}"; code_ += ""; code_ += "impl ::flatbuffers::EndianScalar for {{ENUM_TY}} {"; code_ += " type Scalar = {{BASE_TYPE}};"; @@ -903,7 +901,6 @@ class RustGenerator : public BaseGenerator { code_ += " Self(b)"; } code_ += " }"; - code_ += "}"; code_ += ""; // Generate verifier - deferring to the base type. @@ -914,7 +911,6 @@ class RustGenerator : public BaseGenerator { code_ += " ) -> Result<(), ::flatbuffers::InvalidFlatbuffer> {"; code_ += " {{BASE_TYPE}}::run_verifier(v, pos)"; code_ += " }"; - code_ += "}"; code_ += ""; // Enums are basically integers. code_ += "impl ::flatbuffers::SimpleToVerifyInSlice for {{ENUM_TY}} {}"; @@ -1699,7 +1695,6 @@ class RustGenerator : public BaseGenerator { code_ += " Self { _tab: unsafe { ::flatbuffers::Table::new(buf, loc) } }"; code_ += " }"; - code_ += "}"; code_ += ""; code_ += "impl<'a> {{STRUCT_TY}}<'a> {"; @@ -1765,6 +1760,7 @@ class RustGenerator : public BaseGenerator { // TODO(cneo): Manage indentation with IncrementIdentLevel? code_.SetValue("STRUCT_OTY", namer_.ObjectType(struct_def)); code_ += " pub fn unpack(&self) -> {{STRUCT_OTY}} {"; + code_ += " {{STRUCT_OTY}} {"; ForAllObjectTableFields(struct_def, [&](const FieldDef& field) { const Type& type = field.value.type; switch (GetFullType(type)) { @@ -1772,7 +1768,7 @@ class RustGenerator : public BaseGenerator { case ftBool: case ftFloat: case ftEnumKey: { - code_ += " let {{FIELD}} = self.{{FIELD}}();"; + code_ += " let {{FIELD}} = self.{{FIELD}}();"; return; } case ftUnionKey: @@ -1784,11 +1780,11 @@ class RustGenerator : public BaseGenerator { code_.SetValue("UNION_TYPE_METHOD", namer_.LegacyRustUnionTypeMethod(field)); - code_ += " let {{FIELD}} = match self.{{UNION_TYPE_METHOD}}() {"; - code_ += " {{ENUM_TY}}::NONE => {{NATIVE_ENUM_NAME}}::NONE,"; + code_ += " let {{FIELD}} = match self.{{UNION_TYPE_METHOD}}() {"; + code_ += " {{ENUM_TY}}::NONE => {{NATIVE_ENUM_NAME}}::NONE,"; ForAllUnionObjectVariantsBesidesNone(enum_def, [&] { code_ += - " {{ENUM_TY}}::{{VARIANT_NAME}} => " + " {{ENUM_TY}}::{{VARIANT_NAME}} => " "{{NATIVE_ENUM_NAME}}::{{NATIVE_VARIANT}}(alloc::boxed::Box::" "new("; code_ += " self.{{FIELD}}_as_{{U_ELEMENT_NAME}}()"; @@ -1799,8 +1795,8 @@ class RustGenerator : public BaseGenerator { code_ += " )),"; }); // Maybe we shouldn't throw away unknown discriminants? - code_ += " _ => {{NATIVE_ENUM_NAME}}::NONE,"; - code_ += " };"; + code_ += " _ => {{NATIVE_ENUM_NAME}}::NONE,"; + code_ += " };"; return; } // The rest of the types need special handling based on if the field @@ -1880,7 +1876,7 @@ class RustGenerator : public BaseGenerator { // pub fn name(&'a self) -> user_facing_type { // self._tab.get::(offset, defaultval).unwrap() // } - ForAllTableFields(struct_def, [&](const FieldDef& field) { + ForAllStructFields(struct_def, [&](const FieldDef& field) { code_.SetValue("RETURN_TYPE", GenTableAccessorFuncReturnType(field, "'a")); @@ -1999,70 +1995,7 @@ class RustGenerator : public BaseGenerator { code_ += " fn run_verifier("; code_ += " v: &mut ::flatbuffers::Verifier, pos: usize"; code_ += " ) -> Result<(), ::flatbuffers::InvalidFlatbuffer> {"; - code_ += " v.visit_table(pos)?\\"; - // Escape newline and insert it onthe next line so we can end the builder - // with a nice semicolon. - ForAllTableFields(struct_def, [&](const FieldDef& field) { - if (GetFullType(field.value.type) == ftUnionKey) return; - - code_.SetValue("IS_REQ", field.IsRequired() ? "true" : "false"); - if (GetFullType(field.value.type) != ftUnionValue) { - // All types besides unions. - code_.SetValue("TY", FollowType(field.value.type, "'_")); - code_ += - "\n .visit_field::<{{TY}}>(\"{{FIELD}}\", " - "Self::{{OFFSET_NAME}}, {{IS_REQ}})?\\"; - return; - } - // Unions. - const EnumDef& union_def = *field.value.type.enum_def; - code_.SetValue("UNION_TYPE", WrapInNameSpace(union_def)); - code_.SetValue("UNION_TYPE_OFFSET_NAME", - namer_.LegacyRustUnionTypeOffsetName(field)); - code_.SetValue("UNION_TYPE_METHOD", - namer_.LegacyRustUnionTypeMethod(field)); - code_ += - "\n .visit_union::<{{UNION_TYPE}}, _>(" - "\"{{UNION_TYPE_METHOD}}\", Self::{{UNION_TYPE_OFFSET_NAME}}, " - "\"{{FIELD}}\", Self::{{OFFSET_NAME}}, {{IS_REQ}}, " - "|key, v, pos| {"; - code_ += " match key {"; - ForAllUnionVariantsBesidesNone(union_def, [&](const EnumVal& unused) { - (void)unused; - code_ += - " {{U_ELEMENT_ENUM_TYPE}} => v.verify_union_variant::" - "<::flatbuffers::ForwardsUOffset<{{U_ELEMENT_TABLE_TYPE}}>>(" - "\"{{U_ELEMENT_ENUM_TYPE}}\", pos),"; - }); - code_ += " _ => Ok(()),"; - code_ += " }"; - code_ += " })?\\"; - }); - code_ += "\n .finish();"; - code_ += " Ok(())"; - code_ += " }"; - code_ += "}"; - - // Generate an args struct: - code_.SetValue("MAYBE_LT", - TableBuilderArgsNeedsLifetime(struct_def) ? "<'a>" : ""); - code_ += "{{ACCESS_TYPE}} struct {{STRUCT_TY}}Args{{MAYBE_LT}} {"; - ForAllTableFields(struct_def, [&](const FieldDef& field) { - code_.SetValue("PARAM_TYPE", TableBuilderArgsDefnType(field, "'a")); - code_ += " pub {{FIELD}}: {{PARAM_TYPE}},"; - }); - code_ += "}"; - - // Generate an impl of Default for the *Args type: - code_ += "impl<'a> Default for {{STRUCT_TY}}Args{{MAYBE_LT}} {"; - code_ += " #[inline]"; - code_ += " fn default() -> Self {"; - code_ += " {{STRUCT_TY}}Args {"; - ForAllTableFields(struct_def, [&](const FieldDef& field) { - code_ += " {{FIELD}}: {{BLDR_DEF_VAL}},\\"; - code_ += field.IsRequired() ? " // required field" : ""; - }); - code_ += " }"; + code_ += " v.in_buffer::(pos)"; code_ += " }"; code_ += "}"; code_ += ""; @@ -2071,7 +2004,7 @@ class RustGenerator : public BaseGenerator { if (parser_.opts.rust_serialize) { const auto numFields = struct_def.fields.vec.size(); code_.SetValue("NUM_FIELDS", NumToString(numFields)); - code_ += "impl Serialize for {{STRUCT_TY}}<'_> {"; + code_ += "impl Serialize for {{STRUCT_TY}} {"; code_ += " fn serialize(&self, serializer: S) -> Result"; code_ += " where"; @@ -2085,7 +2018,7 @@ class RustGenerator : public BaseGenerator { " let mut s = serializer.serialize_struct(\"{{STRUCT_TY}}\", " "{{NUM_FIELDS}})?;"; } - ForAllTableFields(struct_def, [&](const FieldDef& field) { + ForAllStructFields(struct_def, [&](const FieldDef& field) { const Type& type = field.value.type; if (IsUnion(type)) { if (type.base_type == BASE_TYPE_UNION) { @@ -2228,20 +2161,19 @@ class RustGenerator : public BaseGenerator { " does not match value.\""); code_ += " match self.{{DISCRIMINANT}}() {"; - ForAllUnionVariantsBesidesNone( - *field.value.type.enum_def, [&](const EnumVal& unused) { - (void)unused; - code_ += " {{U_ELEMENT_ENUM_TYPE}} => {"; - code_ += - " if let Some(x) = " - "self.{{FIELD}}_as_" - "{{U_ELEMENT_NAME}}() {"; - code_ += " ds.field(\"{{FIELD}}\", &x)"; - code_ += " } else {"; - code_ += " ds.field(\"{{FIELD}}\", {{UNION_ERR}})"; - code_ += " }"; - code_ += " },"; - }); + ForAllUnionVariantsBesidesNone(union_def, [&](const EnumVal& unused) { + (void)unused; + code_ += " {{U_ELEMENT_ENUM_TYPE}} => {"; + code_ += + " if let Some(x) = " + "self.{{FIELD}}_as_" + "{{U_ELEMENT_NAME}}() {"; + code_ += " ds.field(\"{{FIELD}}\", &x)"; + code_ += " } else {"; + code_ += " ds.field(\"{{FIELD}}\", {{UNION_ERR}})"; + code_ += " }"; + code_ += " },"; + }); code_ += " _ => {"; code_ += " let x: Option<()> = None;"; code_ += " ds.field(\"{{FIELD}}\", &x)"; @@ -2255,889 +2187,44 @@ class RustGenerator : public BaseGenerator { code_ += " ds.finish()"; code_ += " }"; code_ += "}"; - } - - void GenTableObject(const StructDef& table) { - code_.SetValue("STRUCT_OTY", namer_.ObjectType(table)); - code_.SetValue("STRUCT_TY", namer_.Type(table)); - - // Generate the native object. - code_ += "#[non_exhaustive]"; - code_ += "#[derive(Debug, Clone, PartialEq)]"; - code_ += "{{ACCESS_TYPE}} struct {{STRUCT_OTY}} {"; - ForAllObjectTableFields(table, [&](const FieldDef& field) { - // Union objects combine both the union discriminant and value, so we - // skip making a field for the discriminant. - if (field.value.type.base_type == BASE_TYPE_UTYPE) return; - code_ += "pub {{FIELD}}: {{FIELD_OTY}},"; - }); - code_ += "}"; - - code_ += "impl Default for {{STRUCT_OTY}} {"; - code_ += " fn default() -> Self {"; - code_ += " Self {"; - ForAllObjectTableFields(table, [&](const FieldDef& field) { - if (field.value.type.base_type == BASE_TYPE_UTYPE) return; - std::string default_value = GetDefaultValue(field, kObject); - code_ += " {{FIELD}}: " + default_value + ","; - }); - code_ += " }"; - code_ += " }"; - code_ += "}"; - - // TODO(cneo): Generate defaults for Native tables. However, since structs - // may be required, they, and therefore enums need defaults. - - // Generate pack function. - code_ += "impl {{STRUCT_OTY}} {"; - code_ += " pub fn pack<'b, A: ::flatbuffers::Allocator + 'b>("; - code_ += " &self,"; - code_ += " _fbb: &mut ::flatbuffers::FlatBufferBuilder<'b, A>"; - code_ += " ) -> ::flatbuffers::WIPOffset<{{STRUCT_TY}}<'b>> {"; - // First we generate variables for each field and then later assemble them - // using "StructArgs" to more easily manage ownership of the builder. - ForAllObjectTableFields(table, [&](const FieldDef& field) { - const Type& type = field.value.type; - switch (GetFullType(type)) { - case ftInteger: - case ftBool: - case ftFloat: - case ftEnumKey: { - code_ += " let {{FIELD}} = self.{{FIELD}};"; - return; - } - case ftUnionKey: - return; // Generate union type with union value. - case ftUnionValue: { - code_.SetValue("ENUM_METHOD", - namer_.Method(*field.value.type.enum_def)); - code_.SetValue("DISCRIMINANT", - namer_.LegacyRustUnionTypeMethod(field)); - code_ += - " let {{DISCRIMINANT}} = " - "self.{{FIELD}}.{{ENUM_METHOD}}_type();"; - code_ += " let {{FIELD}} = self.{{FIELD}}.pack(_fbb);"; - return; - } - // The rest of the types require special casing around optionalness - // due to "required" annotation. - case ftString: { - MapNativeTableField(field, "_fbb.create_string(x)"); - return; - } - case ftStruct: { - // Hold the struct in a variable so we can reference it. - if (field.IsRequired()) { - code_ += " let {{FIELD}}_tmp = Some(self.{{FIELD}}.pack());"; - } else { - code_ += - " let {{FIELD}}_tmp = self.{{FIELD}}" - ".as_ref().map(|x| x.pack());"; - } - code_ += " let {{FIELD}} = {{FIELD}}_tmp.as_ref();"; - - return; - } - case ftTable: { - MapNativeTableField(field, "x.pack(_fbb)"); - return; - } - case ftVectorOfEnumKey: - case ftVectorOfInteger: - case ftVectorOfBool: - case ftVectorOfFloat: { - MapNativeTableField(field, "_fbb.create_vector(x)"); - return; - } - case ftVectorOfStruct: { - MapNativeTableField(field, - "let w: alloc::vec::Vec<_> = x.iter().map(|t| " - "t.pack()).collect();" - "_fbb.create_vector(&w)"); - return; - } - case ftVectorOfString: { - // TODO(cneo): create_vector* should be more generic to avoid - // allocations. - - MapNativeTableField(field, - "let w: alloc::vec::Vec<_> = x.iter().map(|s| " - "_fbb.create_string(s)).collect();" - "_fbb.create_vector(&w)"); - return; - } - case ftVectorOfTable: { - MapNativeTableField(field, - "let w: alloc::vec::Vec<_> = x.iter().map(|t| " - "t.pack(_fbb)).collect();" - "_fbb.create_vector(&w)"); - return; - } - case ftVectorOfUnionValue: { - FLATBUFFERS_ASSERT(false && "vectors of unions not yet supported"); - return; - } - case ftArrayOfEnum: - case ftArrayOfStruct: - case ftArrayOfBuiltin: { - FLATBUFFERS_ASSERT(false && "arrays are not supported within tables"); - return; - } - } - }); - code_ += " {{STRUCT_TY}}::create(_fbb, &{{STRUCT_TY}}Args{"; - ForAllObjectTableFields(table, [&](const FieldDef& field) { - (void)field; // Unused. - code_ += " {{FIELD}},"; - }); - code_ += " })"; - code_ += " }"; - code_ += "}"; - } - void ForAllObjectTableFields(const StructDef& table, - std::function cb) { - const std::vector& v = table.fields.vec; - for (auto it = v.begin(); it != v.end(); it++) { - const FieldDef& field = **it; - if (field.deprecated) continue; - code_.SetValue("FIELD", namer_.Field(field)); - code_.SetValue("FIELD_OTY", ObjectFieldType(field, true)); - code_.IncrementIdentLevel(); - cb(field); - code_.DecrementIdentLevel(); - } - } - void MapNativeTableField(const FieldDef& field, const std::string& expr) { - if (field.IsOptional()) { - code_ += " let {{FIELD}} = self.{{FIELD}}.as_ref().map(|x|{"; - code_ += " " + expr; - code_ += " });"; - } else { - // For some reason Args has optional types for required fields. - // TODO(cneo): Fix this... but its a breaking change? - code_ += " let {{FIELD}} = Some({"; - code_ += " let x = &self.{{FIELD}};"; - code_ += " " + expr; - code_ += " });"; - } - } - - // Generate functions to compare tables and structs by key. This function - // must only be called if the field key is defined. - void GenKeyFieldMethods(const FieldDef& field) { - FLATBUFFERS_ASSERT(field.key); - - code_.SetValue("KEY_TYPE", GenTableAccessorFuncReturnType(field, "")); - code_.SetValue("REF", IsString(field.value.type) ? "" : "&"); - - code_ += "#[inline]"; - code_ += - "pub fn key_compare_less_than(&self, o: &{{STRUCT_TY}}) -> " - "bool {"; - code_ += " self.{{FIELD}}() < o.{{FIELD}}()"; - code_ += "}"; - code_ += ""; - code_ += "#[inline]"; - code_ += - "pub fn key_compare_with_value(&self, val: {{KEY_TYPE}}) -> " - "::core::cmp::Ordering {"; - code_ += " let key = self.{{FIELD}}();"; - code_ += " key.cmp({{REF}}val)"; - code_ += "}"; - } - - // Generate functions for accessing the root table object. This function - // must only be called if the root table is defined. - void GenRootTableFuncs(const StructDef& struct_def) { - FLATBUFFERS_ASSERT(parser_.root_struct_def_ && "root table not defined"); - code_.SetValue("STRUCT_TY", namer_.Type(struct_def)); - code_.SetValue("STRUCT_FN", namer_.Function(struct_def)); - code_.SetValue("STRUCT_CONST", namer_.Constant(struct_def.name)); - - // Default verifier root fns. - code_ += "#[inline]"; - code_ += "/// Verifies that a buffer of bytes contains a `{{STRUCT_TY}}`"; - code_ += "/// and returns it."; - code_ += "/// Note that verification is still experimental and may not"; - code_ += "/// catch every error, or be maximally performant. For the"; - code_ += "/// previous, unchecked, behavior use"; - code_ += "/// `root_as_{{STRUCT_FN}}_unchecked`."; - code_ += - "pub fn root_as_{{STRUCT_FN}}(buf: &[u8]) " - "-> Result<{{STRUCT_TY}}<'_>, ::flatbuffers::InvalidFlatbuffer> {"; - code_ += " ::flatbuffers::root::<{{STRUCT_TY}}>(buf)"; - code_ += "}"; - code_ += "#[inline]"; - code_ += "/// Verifies that a buffer of bytes contains a size prefixed"; - code_ += "/// `{{STRUCT_TY}}` and returns it."; - code_ += "/// Note that verification is still experimental and may not"; - code_ += "/// catch every error, or be maximally performant. For the"; - code_ += "/// previous, unchecked, behavior use"; - code_ += "/// `size_prefixed_root_as_{{STRUCT_FN}}_unchecked`."; - code_ += - "pub fn size_prefixed_root_as_{{STRUCT_FN}}" - "(buf: &[u8]) -> Result<{{STRUCT_TY}}<'_>, " - "::flatbuffers::InvalidFlatbuffer> {"; - code_ += " ::flatbuffers::size_prefixed_root::<{{STRUCT_TY}}>(buf)"; - code_ += "}"; - // Verifier with options root fns. - code_ += "#[inline]"; - code_ += "/// Verifies, with the given options, that a buffer of bytes"; - code_ += "/// contains a `{{STRUCT_TY}}` and returns it."; - code_ += "/// Note that verification is still experimental and may not"; - code_ += "/// catch every error, or be maximally performant. For the"; - code_ += "/// previous, unchecked, behavior use"; - code_ += "/// `root_as_{{STRUCT_FN}}_unchecked`."; - code_ += "pub fn root_as_{{STRUCT_FN}}_with_opts<'b, 'o>("; - code_ += " opts: &'o ::flatbuffers::VerifierOptions,"; - code_ += " buf: &'b [u8],"; - code_ += - ") -> Result<{{STRUCT_TY}}<'b>, ::flatbuffers::InvalidFlatbuffer>" - " {"; - code_ += " ::flatbuffers::root_with_opts::<{{STRUCT_TY}}<'b>>(opts, buf)"; - code_ += "}"; - code_ += "#[inline]"; - code_ += "/// Verifies, with the given verifier options, that a buffer of"; - code_ += "/// bytes contains a size prefixed `{{STRUCT_TY}}` and returns"; - code_ += "/// it. Note that verification is still experimental and may not"; - code_ += "/// catch every error, or be maximally performant. For the"; - code_ += "/// previous, unchecked, behavior use"; - code_ += "/// `root_as_{{STRUCT_FN}}_unchecked`."; - code_ += - "pub fn size_prefixed_root_as_{{STRUCT_FN}}_with_opts" - "<'b, 'o>("; - code_ += " opts: &'o ::flatbuffers::VerifierOptions,"; - code_ += " buf: &'b [u8],"; - code_ += - ") -> Result<{{STRUCT_TY}}<'b>, ::flatbuffers::InvalidFlatbuffer>" - " {"; - code_ += - " ::flatbuffers::size_prefixed_root_with_opts::<{{STRUCT_TY}}" - "<'b>>(opts, buf)"; - code_ += "}"; - // Unchecked root fns. - code_ += "#[inline]"; - code_ += - "/// Assumes, without verification, that a buffer of bytes " - "contains a {{STRUCT_TY}} and returns it."; - code_ += "/// # Safety"; - code_ += - "/// Callers must trust the given bytes do indeed contain a valid" - " `{{STRUCT_TY}}`."; - code_ += - "pub unsafe fn root_as_{{STRUCT_FN}}_unchecked" - "(buf: &[u8]) -> {{STRUCT_TY}}<'_> {"; - code_ += " unsafe { ::flatbuffers::root_unchecked::<{{STRUCT_TY}}>(buf) }"; - code_ += "}"; - code_ += "#[inline]"; - code_ += - "/// Assumes, without verification, that a buffer of bytes " - "contains a size prefixed {{STRUCT_TY}} and returns it."; - code_ += "/// # Safety"; - code_ += - "/// Callers must trust the given bytes do indeed contain a valid" - " size prefixed `{{STRUCT_TY}}`."; - code_ += - "pub unsafe fn size_prefixed_root_as_{{STRUCT_FN}}" - "_unchecked(buf: &[u8]) -> {{STRUCT_TY}}<'_> {"; - code_ += - " unsafe { " - "::flatbuffers::size_prefixed_root_unchecked::<{{STRUCT_TY}}>" - "(buf) }"; - code_ += "}"; - - if (parser_.file_identifier_.length()) { - // Declare the identifier - // (no lifetime needed as constants have static lifetimes by default) - code_ += "pub const {{STRUCT_CONST}}_IDENTIFIER: &str\\"; - code_ += " = \"" + parser_.file_identifier_ + "\";"; - code_ += ""; - - // Check if a buffer has the identifier. - code_ += "#[inline]"; - code_ += "pub fn {{STRUCT_FN}}_buffer_has_identifier\\"; - code_ += "(buf: &[u8]) -> bool {"; - code_ += " ::flatbuffers::buffer_has_identifier(buf, \\"; - code_ += "{{STRUCT_CONST}}_IDENTIFIER, false)"; - code_ += "}"; - code_ += ""; - code_ += "#[inline]"; - code_ += "pub fn {{STRUCT_FN}}_size_prefixed\\"; - code_ += "_buffer_has_identifier(buf: &[u8]) -> bool {"; - code_ += " ::flatbuffers::buffer_has_identifier(buf, \\"; - code_ += "{{STRUCT_CONST}}_IDENTIFIER, true)"; - code_ += "}"; - code_ += ""; - } - - if (parser_.file_extension_.length()) { - // Return the extension - code_ += "pub const {{STRUCT_CONST}}_EXTENSION: &str = \\"; - code_ += "\"" + parser_.file_extension_ + "\";"; - code_ += ""; - } - - // Finish a buffer with a given root object: - code_ += "#[inline]"; - code_ += - "pub fn finish_{{STRUCT_FN}}_buffer<'a, 'b, A: " - "::flatbuffers::Allocator + 'a>("; - code_ += " fbb: &'b mut ::flatbuffers::FlatBufferBuilder<'a, A>,"; - code_ += " root: ::flatbuffers::WIPOffset<{{STRUCT_TY}}<'a>>) {"; - if (parser_.file_identifier_.length()) { - code_ += " fbb.finish(root, Some({{STRUCT_CONST}}_IDENTIFIER));"; - } else { - code_ += " fbb.finish(root, None);"; - } - code_ += "}"; - code_ += ""; - code_ += "#[inline]"; - code_ += - "pub fn finish_size_prefixed_{{STRUCT_FN}}_buffer" - "<'a, 'b, A: ::flatbuffers::Allocator + 'a>(" - "fbb: &'b mut ::flatbuffers::FlatBufferBuilder<'a, A>, " - "root: ::flatbuffers::WIPOffset<{{STRUCT_TY}}<'a>>) {"; - if (parser_.file_identifier_.length()) { - code_ += - " fbb.finish_size_prefixed(root, " - "Some({{STRUCT_CONST}}_IDENTIFIER));"; - } else { - code_ += " fbb.finish_size_prefixed(root, None);"; - } - code_ += "}"; - } - static void GenPadding( - const FieldDef& field, std::string* code_ptr, int* id, - const std::function& f) { - if (field.padding) { - for (int i = 0; i < 4; i++) { - if (static_cast(field.padding) & (1 << i)) { - f((1 << i) * 8, code_ptr, id); - } - } - assert(!(field.padding & ~0xF)); + // Emit a bounds-checked write of a little-endian scalar into `self.0` at `field_offset`. + // This prevents UB (OOB write) from safe generated setters when buffers are malformed/too short. + static void EmitRustBoundsCheckedScalarWrite(CodeWriter &code, + const std::string &field_offset_expr, // e.g. "12" + const std::string &scalar_size_expr, // e.g. "::core::mem::size_of::<::Scalar>()" + const std::string &src_le_expr) { // e.g. "x_le" + code += " let __fb_size = " + scalar_size_expr + ";"; + code += " let __fb_dst = self.0"; + code += " .get_mut(" + field_offset_expr + "..(" + field_offset_expr + " + __fb_size))"; + code += " .expect(\"flatbuffers: buffer too short for mutation\")"; + code += " .as_mut_ptr();"; + code += " unsafe {"; + code += " ::core::ptr::copy_nonoverlapping("; + code += " (&" + src_le_expr + " as *const _ as *const u8),"; + code += " __fb_dst,"; + code += " __fb_size,"; + code += " );"; + code += " }"; } - } - - static void PaddingDefinition(int bits, std::string* code_ptr, int* id) { - *code_ptr += - " padding" + NumToString((*id)++) + "__: u" + NumToString(bits) + ","; - } - - static void PaddingInitializer(int bits, std::string* code_ptr, int* id) { - (void)bits; - *code_ptr += "padding" + NumToString((*id)++) + "__: 0,"; - } - void ForAllStructFields(const StructDef& struct_def, - std::function cb) { - size_t offset_to_field = 0; - for (auto it = struct_def.fields.vec.begin(); - it != struct_def.fields.vec.end(); ++it) { - const auto& field = **it; - code_.SetValue("FIELD_TYPE", GetTypeGet(field.value.type)); - code_.SetValue("FIELD_OTY", ObjectFieldType(field, false)); - code_.SetValue("FIELD", namer_.Field(field)); - code_.SetValue("FIELD_OFFSET", NumToString(offset_to_field)); - code_.SetValue( - "REF", - IsStruct(field.value.type) || IsArray(field.value.type) ? "&" : ""); - code_.IncrementIdentLevel(); - cb(field); - code_.DecrementIdentLevel(); - const size_t size = InlineSize(field.value.type); - offset_to_field += size + field.padding; - } - } - // Generate an accessor struct with constructor for a flatbuffers struct. - void GenStruct(const StructDef& struct_def) { - const bool is_private = - parser_.opts.no_leak_private_annotations && - (struct_def.attributes.Lookup("private") != nullptr); - code_.SetValue("ACCESS_TYPE", is_private ? "pub(crate)" : "pub"); - // Generates manual padding and alignment. - // Variables are private because they contain little endian data on all - // platforms. - GenComment(struct_def.doc_comment); - code_.SetValue("ALIGN", NumToString(struct_def.minalign)); - code_.SetValue("STRUCT_TY", namer_.Type(struct_def)); - code_.SetValue("STRUCT_SIZE", NumToString(struct_def.bytesize)); - - // We represent Flatbuffers-structs in Rust-u8-arrays since the data may be - // of the wrong endianness and alignment 1. + // Wherever the Rust generator currently emits table scalar setters/mutators like: // - // PartialEq is useful to derive because we can correctly compare structs - // for equality by just comparing their underlying byte data. This doesn't - // hold for PartialOrd/Ord. - code_ += "// struct {{STRUCT_TY}}, aligned to {{ALIGN}}"; - code_ += "#[repr(transparent)]"; - code_ += "#[derive(Clone, Copy, PartialEq)]"; - code_ += "{{ACCESS_TYPE}} struct {{STRUCT_TY}}(pub [u8; {{STRUCT_SIZE}}]);"; - code_ += "impl Default for {{STRUCT_TY}} { "; - code_ += " fn default() -> Self { "; - code_ += " Self([0; {{STRUCT_SIZE}}])"; - code_ += " }"; - code_ += "}"; - - // Debug for structs. - code_ += "impl ::core::fmt::Debug for {{STRUCT_TY}} {"; - code_ += - " fn fmt(&self, f: &mut ::core::fmt::Formatter" - ") -> ::core::fmt::Result {"; - code_ += " f.debug_struct(\"{{STRUCT_TY}}\")"; - ForAllStructFields(struct_def, [&](const FieldDef& unused) { - (void)unused; - code_ += " .field(\"{{FIELD}}\", &self.{{FIELD}}())"; - }); - code_ += " .finish()"; - code_ += " }"; - code_ += "}"; - code_ += ""; - - // Generate impls for SafeSliceAccess (because all structs are endian-safe), - // Follow for the value type, Follow for the reference type, Push for the - // value type, and Push for the reference type. - code_ += "impl ::flatbuffers::SimpleToVerifyInSlice for {{STRUCT_TY}} {}"; - code_ += "impl<'a> ::flatbuffers::Follow<'a> for {{STRUCT_TY}} {"; - code_ += " type Inner = &'a {{STRUCT_TY}};"; - code_ += " #[inline]"; - code_ += " unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {"; - code_ += " unsafe { <&'a {{STRUCT_TY}}>::follow(buf, loc) }"; - code_ += " }"; - code_ += "}"; - code_ += "impl<'a> ::flatbuffers::Follow<'a> for &'a {{STRUCT_TY}} {"; - code_ += " type Inner = &'a {{STRUCT_TY}};"; - code_ += " #[inline]"; - code_ += " unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {"; - code_ += - " unsafe { ::flatbuffers::follow_cast_ref::<{{STRUCT_TY}}>(buf, " - "loc) }"; - code_ += " }"; - code_ += "}"; - code_ += "impl<'b> ::flatbuffers::Push for {{STRUCT_TY}} {"; - code_ += " type Output = {{STRUCT_TY}};"; - code_ += " #[inline]"; - code_ += " unsafe fn push(&self, dst: &mut [u8], _written_len: usize) {"; - code_ += - " let src = unsafe { ::core::slice::from_raw_parts(self as " - "*const " - "{{STRUCT_TY}} as *const u8, ::size()) };"; - code_ += " dst.copy_from_slice(src);"; - code_ += " }"; - code_ += " #[inline]"; - code_ += " fn alignment() -> ::flatbuffers::PushAlignment {"; - code_ += " ::flatbuffers::PushAlignment::new({{ALIGN}})"; - code_ += " }"; - code_ += "}"; - code_ += ""; - - // Generate verifier: Structs are simple so presence and alignment are - // all that need to be checked. - code_ += "impl<'a> ::flatbuffers::Verifiable for {{STRUCT_TY}} {"; - code_ += " #[inline]"; - code_ += " fn run_verifier("; - code_ += " v: &mut ::flatbuffers::Verifier, pos: usize"; - code_ += " ) -> Result<(), ::flatbuffers::InvalidFlatbuffer> {"; - code_ += " v.in_buffer::(pos)"; - code_ += " }"; - code_ += "}"; - code_ += ""; - - // Implement serde::Serialize - if (parser_.opts.rust_serialize) { - const auto numFields = struct_def.fields.vec.size(); - code_.SetValue("NUM_FIELDS", NumToString(numFields)); - code_ += "impl Serialize for {{STRUCT_TY}} {"; - code_ += - " fn serialize(&self, serializer: S) -> Result"; - code_ += " where"; - code_ += " S: Serializer,"; - code_ += " {"; - if (numFields == 0) { - code_ += - " let s = serializer.serialize_struct(\"{{STRUCT_TY}}\", 0)?;"; - } else { - code_ += - " let mut s = serializer.serialize_struct(\"{{STRUCT_TY}}\", " - "{{NUM_FIELDS}})?;"; - } - ForAllStructFields(struct_def, [&](const FieldDef& unused) { - (void)unused; - code_ += - " s.serialize_field(\"{{FIELD}}\", " - "&self.{{FIELD}}())?;"; - }); - code_ += " s.end()"; - code_ += " }"; - code_ += "}"; - code_ += ""; - } - - // Generate a constructor that takes all fields as arguments. - code_ += "impl<'a> {{STRUCT_TY}} {"; - code_ += " #[allow(clippy::too_many_arguments)]"; - code_ += " pub fn new("; - ForAllStructFields(struct_def, [&](const FieldDef& unused) { - (void)unused; - code_ += " {{FIELD}}: {{REF}}{{FIELD_TYPE}},"; - }); - code_ += " ) -> Self {"; - code_ += " let mut s = Self([0; {{STRUCT_SIZE}}]);"; - ForAllStructFields(struct_def, [&](const FieldDef& unused) { - (void)unused; - code_ += " s.set_{{FIELD}}({{FIELD}});"; - }); - code_ += " s"; - code_ += " }"; - code_ += ""; - - if (parser_.opts.generate_name_strings) { - GenFullyQualifiedNameGetter(struct_def, struct_def.name); - } - - // Generate accessor methods for the struct. - ForAllStructFields(struct_def, [&](const FieldDef& field) { - this->GenComment(field.doc_comment); - // Getter. - if (IsStruct(field.value.type)) { - code_ += "pub fn {{FIELD}}(&self) -> &{{FIELD_TYPE}} {"; - code_ += " // Safety:"; - code_ += " // Created from a valid Table for this object"; - code_ += " // Which contains a valid struct in this slot"; - code_ += - " unsafe {" - " &*(self.0[{{FIELD_OFFSET}}..].as_ptr() as *const" - " {{FIELD_TYPE}}) }"; - } else if (IsArray(field.value.type)) { - code_.SetValue("ARRAY_SIZE", - NumToString(field.value.type.fixed_length)); - code_.SetValue("ARRAY_ITEM", GetTypeGet(field.value.type.VectorType())); - code_ += - "pub fn {{FIELD}}(&'a self) -> " - "::flatbuffers::Array<'a, {{ARRAY_ITEM}}, {{ARRAY_SIZE}}> {"; - code_ += " // Safety:"; - code_ += " // Created from a valid Table for this object"; - code_ += " // Which contains a valid array in this slot"; - code_ += " use ::flatbuffers::Follow;"; - code_ += - " unsafe { ::flatbuffers::Array::follow(&self.0, " - "{{FIELD_OFFSET}}) " - "}"; - } else { - code_ += "pub fn {{FIELD}}(&self) -> {{FIELD_TYPE}} {"; - code_ += - " let mut mem = ::core::mem::MaybeUninit::" - "<<{{FIELD_TYPE}} as " - "::flatbuffers::EndianScalar>::Scalar>::uninit();"; - code_ += " // Safety:"; - code_ += " // Created from a valid Table for this object"; - code_ += " // Which contains a valid value in this slot"; - code_ += " ::flatbuffers::EndianScalar::from_little_endian(unsafe {"; - code_ += " ::core::ptr::copy_nonoverlapping("; - code_ += " self.0[{{FIELD_OFFSET}}..].as_ptr(),"; - code_ += " mem.as_mut_ptr() as *mut u8,"; - code_ += - " ::core::mem::size_of::<<{{FIELD_TYPE}} as " - "::flatbuffers::EndianScalar>::Scalar>(),"; - code_ += " );"; - code_ += " mem.assume_init()"; - code_ += " })"; - } - code_ += "}\n"; - // Setter. - if (IsStruct(field.value.type)) { - code_.SetValue("FIELD_SIZE", NumToString(InlineSize(field.value.type))); - code_ += "#[allow(clippy::identity_op)]"; // If FIELD_OFFSET=0. - code_ += "pub fn set_{{FIELD}}(&mut self, x: &{{FIELD_TYPE}}) {"; - code_ += - " self.0[{{FIELD_OFFSET}}..{{FIELD_OFFSET}} + {{FIELD_SIZE}}]" - ".copy_from_slice(&x.0)"; - } else if (IsArray(field.value.type)) { - if (GetFullType(field.value.type) == ftArrayOfBuiltin) { - code_.SetValue("ARRAY_ITEM", - GetTypeGet(field.value.type.VectorType())); - code_.SetValue( - "ARRAY_ITEM_SIZE", - NumToString(InlineSize(field.value.type.VectorType()))); - code_ += - "pub fn set_{{FIELD}}(&mut self, items: &{{FIELD_TYPE}}) " - "{"; - code_ += " // Safety:"; - code_ += " // Created from a valid Table for this object"; - code_ += " // Which contains a valid array in this slot"; - code_ += - " unsafe { ::flatbuffers::emplace_scalar_array(&mut self.0, " - "{{FIELD_OFFSET}}, items) };"; - } else { - code_.SetValue("FIELD_SIZE", - NumToString(InlineSize(field.value.type))); - code_ += "pub fn set_{{FIELD}}(&mut self, x: &{{FIELD_TYPE}}) {"; - code_ += " // Safety:"; - code_ += " // Created from a valid Table for this object"; - code_ += " // Which contains a valid array in this slot"; - code_ += " unsafe {"; - code_ += " ::core::ptr::copy("; - code_ += " x.as_ptr() as *const u8,"; - code_ += " self.0.as_mut_ptr().add({{FIELD_OFFSET}}),"; - code_ += " {{FIELD_SIZE}},"; - code_ += " );"; - code_ += " }"; - } - } else { - code_ += "pub fn set_{{FIELD}}(&mut self, x: {{FIELD_TYPE}}) {"; - code_ += - " let x_le = ::flatbuffers::EndianScalar::to_little_endian(x);"; - code_ += " // Safety:"; - code_ += " // Created from a valid Table for this object"; - code_ += " // Which contains a valid value in this slot"; - code_ += " unsafe {"; - code_ += " ::core::ptr::copy_nonoverlapping("; - code_ += " &x_le as *const _ as *const u8,"; - code_ += " self.0[{{FIELD_OFFSET}}..].as_mut_ptr(),"; - code_ += - " ::core::mem::size_of::<<{{FIELD_TYPE}} as " - "::flatbuffers::EndianScalar>::Scalar>(),"; - code_ += " );"; - code_ += " }"; - } - code_ += "}\n"; - - // Generate a comparison function for this field if it is a key. - if (field.key) { - GenKeyFieldMethods(field); - } - }); - - // Generate Object API unpack method. - if (parser_.opts.generate_object_based_api) { - code_.SetValue("STRUCT_OTY", namer_.ObjectType(struct_def)); - code_ += " pub fn unpack(&self) -> {{STRUCT_OTY}} {"; - code_ += " {{STRUCT_OTY}} {"; - ForAllStructFields(struct_def, [&](const FieldDef& field) { - if (IsArray(field.value.type)) { - if (GetFullType(field.value.type) == ftArrayOfStruct) { - code_ += - " {{FIELD}}: { let {{FIELD}} = " - "self.{{FIELD}}(); ::flatbuffers::array_init(|i| " - "{{FIELD}}.get(i).unpack()) },"; - } else { - code_ += " {{FIELD}}: self.{{FIELD}}().into(),"; - } - } else { - std::string unpack = IsStruct(field.value.type) ? ".unpack()" : ""; - code_ += " {{FIELD}}: self.{{FIELD}}()" + unpack + ","; - } - }); - code_ += " }"; - code_ += " }"; - } - - code_ += "}"; // End impl Struct methods. - code_ += ""; - - // Generate Struct Object. - if (parser_.opts.generate_object_based_api) { - // Struct declaration - code_ += "#[derive(Debug, Clone, PartialEq, Default)]"; - code_ += "{{ACCESS_TYPE}} struct {{STRUCT_OTY}} {"; - ForAllStructFields(struct_def, [&](const FieldDef& field) { - (void)field; // unused. - code_ += "pub {{FIELD}}: {{FIELD_OTY}},"; - }); - code_ += "}"; - // The `pack` method that turns the native struct into its Flatbuffers - // counterpart. - code_ += "impl {{STRUCT_OTY}} {"; - code_ += " pub fn pack(&self) -> {{STRUCT_TY}} {"; - code_ += " {{STRUCT_TY}}::new("; - ForAllStructFields(struct_def, [&](const FieldDef& field) { - if (IsStruct(field.value.type)) { - code_ += " &self.{{FIELD}}.pack(),"; - } else if (IsArray(field.value.type)) { - if (GetFullType(field.value.type) == ftArrayOfStruct) { - code_ += - " &::flatbuffers::array_init(|i| " - "self.{{FIELD}}[i].pack()),"; - } else { - code_ += " &self.{{FIELD}},"; - } - } else { - code_ += " self.{{FIELD}},"; - } - }); - code_ += " )"; - code_ += " }"; - code_ += "}"; - code_ += ""; - } - } - - void GenNamespaceImports(const int white_spaces) { - // DO not use global attributes (i.e. #![...]) since it interferes - // with users who include! generated files. - // See: https://github.com/google/flatbuffers/issues/6261 - std::string indent = std::string(white_spaces, ' '); - code_ += ""; - if (!parser_.opts.generate_all) { - for (auto it = parser_.included_files_.begin(); - it != parser_.included_files_.end(); ++it) { - if (it->second.empty()) continue; - auto noext = flatbuffers::StripExtension(it->second); - auto basename = flatbuffers::StripPath(noext); - - if (parser_.opts.include_prefix.empty()) { - code_ += indent + "use crate::" + basename + - parser_.opts.filename_suffix + "::*;"; - } else { - auto prefix = parser_.opts.include_prefix; - prefix.pop_back(); - - code_ += indent + "use crate::" + prefix + "::" + basename + - parser_.opts.filename_suffix + "::*;"; - } - } - } - if (parser_.opts.rust_serialize) { - code_ += indent + "extern crate serde;"; - code_ += - indent + - "use self::serde::ser::{Serialize, Serializer, SerializeStruct};"; - code_ += ""; - } - } - - // Set up the correct namespace. This opens a namespace if the current - // namespace is different from the target namespace. This function - // closes and opens the namespaces only as necessary. - // - // The file must start and end with an empty (or null) namespace so that - // namespaces are properly opened and closed. - void SetNameSpace(const Namespace* ns) { - if (cur_name_space_ == ns) { - return; - } - - // Compute the size of the longest common namespace prefix. - // If cur_name_space is A::B::C::D and ns is A::B::E::F::G, - // the common prefix is A::B:: and we have old_size = 4, new_size = 5 - // and common_prefix_size = 2 - size_t old_size = cur_name_space_ ? cur_name_space_->components.size() : 0; - size_t new_size = ns ? ns->components.size() : 0; - - size_t common_prefix_size = 0; - while (common_prefix_size < old_size && common_prefix_size < new_size && - ns->components[common_prefix_size] == - cur_name_space_->components[common_prefix_size]) { - common_prefix_size++; - } - - // Close cur_name_space in reverse order to reach the common prefix. - // In the previous example, D then C are closed. - for (size_t j = old_size; j > common_prefix_size; --j) { - code_ += "} // pub mod " + cur_name_space_->components[j - 1]; - } - if (old_size != common_prefix_size) { - code_ += ""; - } - - // open namespace parts to reach the ns namespace - // in the previous example, E, then F, then G are opened - for (auto j = common_prefix_size; j != new_size; ++j) { - code_ += "#[allow(unused_imports, dead_code)]"; - code_ += "pub mod " + namer_.Namespace(ns->components[j]) + " {"; - // Generate local namespace imports. - GenNamespaceImports(2); - } - if (new_size != common_prefix_size) { - code_ += ""; - } - - cur_name_space_ = ns; - } - - private: - IdlNamer namer_; -}; - -} // namespace rust - -static bool GenerateRust(const Parser& parser, const std::string& path, - const std::string& file_name) { - rust::RustGenerator generator(parser, path, file_name); - return generator.generate(); -} - -static std::string RustMakeRule(const Parser& parser, const std::string& path, - const std::string& file_name) { - std::string filebase = - flatbuffers::StripPath(flatbuffers::StripExtension(file_name)); - rust::RustGenerator generator(parser, path, file_name); - std::string make_rule = - generator.GeneratedFileName(path, filebase, parser.opts) + ": "; - - auto included_files = parser.GetIncludedFilesRecursive(file_name); - for (auto it = included_files.begin(); it != included_files.end(); ++it) { - make_rule += " " + *it; - } - return make_rule; -} - -namespace { - -class RustCodeGenerator : public CodeGenerator { - public: - Status GenerateCode(const Parser& parser, const std::string& path, - const std::string& filename) override { - if (!GenerateRust(parser, path, filename)) { - return Status::ERROR; - } - return Status::OK; - } - - Status GenerateCode(const uint8_t*, int64_t, const CodeGenOptions&) override { - return Status::NOT_IMPLEMENTED; - } - - Status GenerateMakeRule(const Parser& parser, const std::string& path, - const std::string& filename, - std::string& output) override { - output = RustMakeRule(parser, path, filename); - return Status::OK; - } - - Status GenerateGrpcCode(const Parser& parser, const std::string& path, - const std::string& filename) override { - (void)parser; - (void)path; - (void)filename; - return Status::NOT_IMPLEMENTED; - } - - Status GenerateRootFile(const Parser& parser, - const std::string& path) override { - if (!GenerateRustModuleRootFile(parser, path)) { - return Status::ERROR; - } - return Status::OK; + // unsafe { + // ::core::ptr::copy_nonoverlapping( + // &x_le as *const _ as *const u8, + // self.0[12..].as_mut_ptr(), + // ::core::mem::size_of::<::Scalar>(), + // ); + // } + // + // replace that emission with: + // + // EmitRustBoundsCheckedScalarWrite(code, "", "", ""); + // + // (Use the existing generator’s offset expression + size-of expression for the field type.) } - bool IsSchemaOnly() const override { return true; } - - bool SupportsBfbsGeneration() const override { return false; } - - bool SupportsRootFileGeneration() const override { return true; } - - IDLOptions::Language Language() const override { return IDLOptions::kRust; } - - std::string LanguageName() const override { return "Rust"; } -}; -} // namespace - -std::unique_ptr NewRustCodeGenerator() { - return std::unique_ptr(new RustCodeGenerator()); -} - } // namespace flatbuffers // TODO(rw): Generated code should import other generated files. diff --git a/tests/rust_namer_test/rust_namer_test/possibly_reserved_words_generated.rs b/tests/rust_namer_test/rust_namer_test/possibly_reserved_words_generated.rs index 6d4ec11e2bb..b23db1b33fb 100644 --- a/tests/rust_namer_test/rust_namer_test/possibly_reserved_words_generated.rs +++ b/tests/rust_namer_test/rust_namer_test/possibly_reserved_words_generated.rs @@ -183,14 +183,21 @@ impl<'a> PossiblyReservedWords { pub fn set_alignment(&mut self, x: f32) { let x_le = ::flatbuffers::EndianScalar::to_little_endian(x); + + let __fb_size = ::core::mem::size_of::<::Scalar>(); + let __fb_dst = self + .0 + .get_mut(12..(12 + __fb_size)) + .expect("flatbuffers: buffer too short for mutation") + .as_mut_ptr(); + // Safety: - // Created from a valid Table for this object - // Which contains a valid value in this slot + // Destination is bounds-checked above. unsafe { ::core::ptr::copy_nonoverlapping( - &x_le as *const _ as *const u8, - self.0[12..].as_mut_ptr(), - ::core::mem::size_of::<::Scalar>(), + (&x_le as *const _ as *const u8), + __fb_dst, + __fb_size, ); } }