diff --git a/src/libs/protocols/.gitignore b/src/libs/protocols/.gitignore index 5b5d2f3df2..bc7e10bf31 100644 --- a/src/libs/protocols/.gitignore +++ b/src/libs/protocols/.gitignore @@ -3,3 +3,4 @@ Cargo.lock src/*.rs !src/lib.rs !src/trans.rs +!src/serde_config.rs diff --git a/src/libs/protocols/build.rs b/src/libs/protocols/build.rs index 12818b0571..1ec2da9829 100644 --- a/src/libs/protocols/build.rs +++ b/src/libs/protocols/build.rs @@ -8,7 +8,45 @@ use std::io::{BufRead, BufReader, Read, Write}; use std::path::Path; use std::process::exit; -use ttrpc_codegen::{Codegen, Customize, ProtobufCustomize}; +use protobuf::{ + descriptor::field_descriptor_proto::Type, + reflect::{EnumDescriptor, FieldDescriptor, MessageDescriptor, OneofDescriptor}, +}; +use ttrpc_codegen::{Codegen, Customize, ProtobufCustomize, ProtobufCustomizeCallback}; + +struct GenSerde; + +impl ProtobufCustomizeCallback for GenSerde { + fn message(&self, _message: &MessageDescriptor) -> ProtobufCustomize { + ProtobufCustomize::default().before("#[cfg_attr(feature = \"with-serde\", derive(::serde::Serialize, ::serde::Deserialize))]") + } + + fn enumeration(&self, _enum_type: &EnumDescriptor) -> ProtobufCustomize { + ProtobufCustomize::default().before("#[cfg_attr(feature = \"with-serde\", derive(::serde::Serialize, ::serde::Deserialize))]") + } + + fn oneof(&self, _oneof: &OneofDescriptor) -> ProtobufCustomize { + ProtobufCustomize::default().before("#[cfg_attr(feature = \"with-serde\", derive(::serde::Serialize, ::serde::Deserialize))]") + } + + fn field(&self, field: &FieldDescriptor) -> ProtobufCustomize { + if field.proto().type_() == Type::TYPE_ENUM { + ProtobufCustomize::default().before( + "#[cfg_attr(feature = \"with-serde\", serde(serialize_with = \"crate::serialize_enum_or_unknown\", deserialize_with = \"crate::deserialize_enum_or_unknown\"))]", + ) + } else if field.proto().type_() == Type::TYPE_MESSAGE && field.is_singular() { + ProtobufCustomize::default().before( + "#[cfg_attr(feature = \"with-serde\", serde(serialize_with = \"crate::serialize_message_field\", deserialize_with = \"crate::deserialize_message_field\"))]", + ) + } else { + ProtobufCustomize::default() + } + } + + fn special_field(&self, _message: &MessageDescriptor, _field: &str) -> ProtobufCustomize { + ProtobufCustomize::default().before("#[cfg_attr(feature = \"with-serde\", serde(skip))]") + } +} fn replace_text_in_file(file_name: &str, from: &str, to: &str) -> Result<(), std::io::Error> { let mut src = File::open(file_name)?; diff --git a/src/libs/protocols/src/lib.rs b/src/libs/protocols/src/lib.rs index 801b700601..33f75ca0ea 100644 --- a/src/libs/protocols/src/lib.rs +++ b/src/libs/protocols/src/lib.rs @@ -17,5 +17,13 @@ pub mod health_ttrpc; #[cfg(feature = "async")] pub mod health_ttrpc_async; pub mod oci; +#[cfg(feature = "with-serde")] +mod serde_config; pub mod trans; pub mod types; + +#[cfg(feature = "with-serde")] +pub use serde_config::{ + deserialize_enum_or_unknown, deserialize_message_field, serialize_enum_or_unknown, + serialize_message_field, +}; diff --git a/src/libs/protocols/src/serde_config.rs b/src/libs/protocols/src/serde_config.rs new file mode 100644 index 0000000000..c1a1d2b7c9 --- /dev/null +++ b/src/libs/protocols/src/serde_config.rs @@ -0,0 +1,38 @@ +// Copyright (c) 2023 Ant Group +// +// SPDX-License-Identifier: Apache-2.0 +// + +use protobuf::{EnumOrUnknown, MessageField}; +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "with-serde")] +pub fn serialize_enum_or_unknown( + e: &protobuf::EnumOrUnknown, + s: S, +) -> Result { + e.value().serialize(s) +} + +pub fn serialize_message_field( + e: &protobuf::MessageField, + s: S, +) -> Result { + if e.is_some() { + e.as_ref().unwrap().serialize(s) + } else { + s.serialize_unit() + } +} + +pub fn deserialize_enum_or_unknown<'de, E: Deserialize<'de>, D: serde::Deserializer<'de>>( + d: D, +) -> Result, D::Error> { + i32::deserialize(d).map(EnumOrUnknown::from_i32) +} + +pub fn deserialize_message_field<'de, E: Deserialize<'de>, D: serde::Deserializer<'de>>( + d: D, +) -> Result, D::Error> { + Option::deserialize(d).map(MessageField::from_option) +}