Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ PROTOBUF_SHARED = [
"codegen_traits.rs",
"cord.rs",
"enum.rs",
"extension.rs",
"internal.rs",
"map.rs",
"optional.rs",
Expand Down
73 changes: 70 additions & 3 deletions rust/cpp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,71 @@
use crate::__internal::{Enum, MatcherEq, Private, SealedInternal};
use crate::{
AsMut, AsView, Clear, ClearAndParse, CopyFrom, IntoProxied, Map, MapIter, MapMut, MapValue,
MapView, MergeFrom, Message, MessageMutInterop, Mut, MutProxied, ParseError, ProtoBytes,
ProtoStr, ProtoString, Proxied, Repeated, RepeatedMut, RepeatedView, Serialize, SerializeError,
Singular, TakeFrom, View,
MapView, MergeFrom, Message, MessageMutInterop, MessageViewInterop, Mut, MutProxied,
OwnedMessageInterop, ParseError, ProtoBytes, ProtoStr, ProtoString, Proxied, Repeated,
RepeatedMut, RepeatedView, Serialize, SerializeError, Singular, TakeFrom, View,
};

// Internal helper to get extension.
// This is called by Extension::get.
pub fn get_extension_message<'a, M, T>(m: View<'a, M>, number: u32) -> Option<View<'a, T>>
where
M: Message,
T: Message + Default,
View<'a, M>: MessageViewInterop<'a>,
View<'a, T>: MessageViewInterop<'a>,
{
// SAFETY: m.__unstable_as_raw_message() is valid.
let raw_ptr = m.__unstable_as_raw_message();
// RawMessage is NonNull<...>, and FFI expects it.
// We trust that msg is valid per View invariants.
let raw_msg = unsafe { std::ptr::NonNull::new_unchecked(raw_ptr as *mut _) };
let raw_ext = unsafe { proto2_rust_Message_GetExtension(raw_msg, number as std::ffi::c_int) };
if raw_ext.is_none() {
return None;
}
let raw_ext = raw_ext.unwrap();
// SAFETY: raw_ext is returned by GetExtension from the message m.
// Its lifetime is tied to 'a (lifetime of m).
// The C++ reflection API guarantees that the returned reference is valid as long as the message is.
unsafe {
Some(View::<'a, T>::__unstable_wrap_raw_message_unchecked_lifetime(
raw_ext.as_ptr() as *const _
))
}
}

pub fn has_extension_message<'a, M, T>(m: View<'a, M>, number: u32) -> bool
where
M: Message,
T: Message + Default,
View<'a, M>: MessageViewInterop<'a>,
{
let raw_ptr = m.__unstable_as_raw_message();
// View's raw ptr is *const c_void. RawMessage is NonNull<RawMessageData>.
// Casting is required.
let raw_msg = unsafe { RawMessage::new_unchecked(raw_ptr as *mut _) };
unsafe { proto2_rust_Message_HasExtension(raw_msg, number as c_int) }
}

pub fn set_extension_message<'a, M, T>(mut m: Mut<'a, M>, number: u32, val: T)
where
M: Message,
T: Message + Default,
Mut<'a, M>: MessageMutInterop<'a>,
for<'b> View<'b, T>: MessageViewInterop<'b>,
{
let raw_ptr = m.__unstable_as_raw_message_mut();
let raw_msg = unsafe { RawMessage::new_unchecked(raw_ptr as *mut _) };
let mut_ext = unsafe { proto2_rust_Message_GetMutableExtension(raw_msg, number as c_int) };
if let Some(mut_ext) = mut_ext {
let val_view = val.as_view();
let val_raw = val_view.__unstable_as_raw_message();
let val_raw_msg = unsafe { RawMessage::new_unchecked(val_raw as *mut _) };
// Use copy_from to clear and merge (set semantics).
unsafe { proto2_rust_Message_copy_from(mut_ext, val_raw_msg) };
}
}
use core::fmt::Debug;
use paste::paste;
use std::convert::identity;
Expand Down Expand Up @@ -119,6 +180,12 @@ unsafe extern "C" {
pub fn proto2_rust_Message_copy_from(dst: RawMessage, src: RawMessage);
pub fn proto2_rust_Message_merge_from(dst: RawMessage, src: RawMessage);
pub fn proto2_rust_Message_get_descriptor(m: RawMessage) -> *const std::ffi::c_void;
pub fn proto2_rust_Message_GetExtension(m: RawMessage, number: c_int) -> Option<RawMessage>;
pub fn proto2_rust_Message_HasExtension(m: RawMessage, number: c_int) -> bool;
pub fn proto2_rust_Message_GetMutableExtension(
m: RawMessage,
number: c_int,
) -> Option<RawMessage>;
}

impl Drop for InnerProtoString {
Expand Down
88 changes: 88 additions & 0 deletions rust/cpp_kernel/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <limits>

#include "google/protobuf/descriptor.h"
#include "google/protobuf/message_lite.h"
#include "rust/cpp_kernel/serialized_data.h"
#include "rust/cpp_kernel/strings.h"
Expand Down Expand Up @@ -61,4 +62,91 @@ const void* proto2_rust_Message_get_descriptor(const google::protobuf::MessageLi
return nullptr;
}

const void* proto2_rust_Message_GetExtension(const google::protobuf::MessageLite* msg,
int number) {
if constexpr (kHasFullRuntime) {
auto m = google::protobuf::DynamicCastMessage<google::protobuf::Message>(msg);
if (m == nullptr) {
return nullptr;
}
const google::protobuf::Reflection* reflection = m->GetReflection();
const google::protobuf::Descriptor* descriptor = m->GetDescriptor();
if (descriptor == nullptr) {
return nullptr;
}
const google::protobuf::DescriptorPool* pool = descriptor->file()->pool();
const google::protobuf::FieldDescriptor* field =
pool->FindExtensionByNumber(descriptor, number);
if (field == nullptr) {
return nullptr;
}

// For now, we only support singular message extensions.
if (field->is_repeated() ||
field->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
return nullptr;
}

if (!reflection->HasField(*m, field)) {
return nullptr;
}

const google::protobuf::Message& extension_msg = reflection->GetMessage(*m, field);
return &extension_msg;
}
return nullptr;
}

bool proto2_rust_Message_HasExtension(const google::protobuf::MessageLite* msg,
int number) {
if constexpr (kHasFullRuntime) {
auto m = google::protobuf::DynamicCastMessage<google::protobuf::Message>(msg);
if (m == nullptr) {
return false;
}
const google::protobuf::Reflection* reflection = m->GetReflection();
const google::protobuf::Descriptor* descriptor = m->GetDescriptor();
if (descriptor == nullptr) {
return false;
}
const google::protobuf::DescriptorPool* pool = descriptor->file()->pool();
const google::protobuf::FieldDescriptor* field =
pool->FindExtensionByNumber(descriptor, number);
if (field == nullptr) {
return false;
}
return reflection->HasField(*m, field);
}
return false;
}

void* proto2_rust_Message_GetMutableExtension(google::protobuf::MessageLite* msg,
int number) {
if constexpr (kHasFullRuntime) {
auto m = google::protobuf::DynamicCastMessage<google::protobuf::Message>(msg);
if (m == nullptr) {
return nullptr;
}
const google::protobuf::Reflection* reflection = m->GetReflection();
const google::protobuf::Descriptor* descriptor = m->GetDescriptor();
if (descriptor == nullptr) {
return nullptr;
}
const google::protobuf::DescriptorPool* pool = descriptor->file()->pool();
const google::protobuf::FieldDescriptor* field =
pool->FindExtensionByNumber(descriptor, number);
if (field == nullptr) {
return nullptr;
}

if (field->is_repeated() ||
field->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
return nullptr;
}

return reflection->MutableMessage(m, field);
}
return nullptr;
}

} // extern "C"
81 changes: 81 additions & 0 deletions rust/extension.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use std::marker::PhantomData;

/// An Extension to a protobuf message.
#[derive(Debug, Clone, Copy)]
pub struct Extension<M, T> {
pub(crate) number: u32,
pub(crate) _phantom: PhantomData<(M, T)>,
}

// TODO: This should be generated by the code generator.
impl<M, T> Extension<M, T>
where
M: crate::Message,
T: crate::Message + Default,
{
pub const fn new(number: u32) -> Self {
Self { number, _phantom: PhantomData }
}
}

#[cfg(cpp_kernel)]
impl<M, T> Extension<M, T>
where
M: crate::Message,
T: crate::Message + Default,
{
pub fn get<'a>(&self, m: crate::View<'a, M>) -> Option<crate::View<'a, T>>
where
crate::View<'a, M>: crate::MessageViewInterop<'a>,
<T as crate::Proxied>::View<'a>: crate::MessageViewInterop<'a>,
{
crate::__internal::runtime::get_extension_message::<M, T>(m, self.number)
}

pub fn has<'a>(&self, m: crate::View<'a, M>) -> bool
where
crate::View<'a, M>: crate::MessageViewInterop<'a>,
{
crate::__internal::runtime::has_extension_message::<M, T>(m, self.number)
}

pub fn set<'a>(&self, m: crate::Mut<'a, M>, val: T)
where
crate::Mut<'a, M>: crate::MessageMutInterop<'a>,
for<'b> <T as crate::Proxied>::View<'b>: crate::MessageViewInterop<'b>,
{
crate::__internal::runtime::set_extension_message::<M, T>(m, self.number, val)
}
}

#[cfg(upb_kernel)]
impl<M, T> Extension<M, T>
where
M: crate::Message + crate::__internal::runtime::AssociatedMiniTable,
T: crate::Message + Default + crate::__internal::runtime::AssociatedMiniTable,
{
pub fn get<'a>(&self, m: crate::View<'a, M>) -> Option<crate::View<'a, T>>
where
crate::View<'a, M>: crate::MessageViewInterop<'a>,
<T as crate::Proxied>::View<'a>: crate::MessageViewInterop<'a>,
{
crate::__internal::runtime::get_extension_message::<M, T>(m, self.number)
}

pub fn has<'a>(&self, m: crate::View<'a, M>) -> bool
where
crate::View<'a, M>: crate::MessageViewInterop<'a>,
{
crate::__internal::runtime::has_extension_message::<M, T>(m, self.number)
}

pub fn set<'a>(&self, m: crate::Mut<'a, M>, val: T)
where
crate::Mut<'a, M>: crate::MessageMutInterop<'a>
+ crate::__internal::runtime::UpbGetMessagePtrMut<Msg = M>
+ crate::__internal::runtime::UpbGetArena
+ crate::__internal::runtime::UpbGetMessagePtr,
{
crate::__internal::runtime::set_extension_message::<M, T>(m, self.number, val)
}
}
3 changes: 3 additions & 0 deletions rust/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ pub use crate::codegen_traits::{
pub use crate::cord::{ProtoBytesCow, ProtoStringCow};
pub use crate::map::{Map, MapIter, MapKey, MapMut, MapValue, MapView};

pub use crate::extension::Extension;

pub use crate::optional::Optional;
pub use crate::proxied::{
AsMut, AsView, IntoMut, IntoProxied, IntoView, Mut, MutProxied, Proxied, View,
Expand Down Expand Up @@ -67,6 +69,7 @@ mod codegen_traits;
mod cord;
#[path = "enum.rs"]
mod r#enum;
mod extension;
mod map;
mod optional;
mod primitive;
Expand Down
30 changes: 29 additions & 1 deletion rust/test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# https://developers.google.com/open-source/licenses/bsd

load("@rules_pkg//pkg:mappings.bzl", "pkg_files", "strip_prefix")
load("@rules_rust//rust:defs.bzl", "rust_library")
load("@rules_rust//rust:defs.bzl", "rust_library", "rust_test")
load("//bazel:proto_library.bzl", "proto_library")
load(
"//rust:defs.bzl",
Expand Down Expand Up @@ -542,3 +542,31 @@ rust_library(
srcs = ["same_name_double_alias_exported_deps_consumer.rs"],
deps = [":same_name_double_alias_exported_deps_rust_proto"],
)

rust_test(
name = "extension_test",
srcs = ["extension_test.rs"],
aliases = {
"//rust:protobuf_cpp": "protobuf",
},
edition = "2021",
deps = [
":unittest_cpp_rust_proto",
"//rust:protobuf_cpp",
"@crate_index//:googletest",
],
)

rust_test(
name = "extension_test_upb",
srcs = ["extension_test.rs"],
aliases = {
"//rust:protobuf_upb": "protobuf",
},
edition = "2021",
deps = [
":unittest_upb_rust_proto",
"//rust:protobuf_upb",
"@crate_index//:googletest",
],
)
Loading
Loading