@@ -89,6 +89,7 @@ class dpctl_capi
8989
9090 // memory
9191 DPCTLSyclUSMRef (* Memory_GetUsmPointer_ )(Py_MemoryObject * );
92+ void * (* Memory_GetOpaquePointer_ )(Py_MemoryObject * );
9293 DPCTLSyclContextRef (* Memory_GetContextRef_ )(Py_MemoryObject * );
9394 DPCTLSyclQueueRef (* Memory_GetQueueRef_ )(Py_MemoryObject * );
9495 size_t (* Memory_GetNumBytes_ )(Py_MemoryObject * );
@@ -115,6 +116,7 @@ class dpctl_capi
115116 int (* UsmNDArray_GetFlags_ )(PyUSMArrayObject * );
116117 DPCTLSyclQueueRef (* UsmNDArray_GetQueueRef_ )(PyUSMArrayObject * );
117118 py ::ssize_t (* UsmNDArray_GetOffset_ )(PyUSMArrayObject * );
119+ PyObject * (* UsmNDArray_GetUSMData_ )(PyUSMArrayObject * );
118120 void (* UsmNDArray_SetWritableFlag_ )(PyUSMArrayObject * , int );
119121 PyObject * (* UsmNDArray_MakeSimpleFromMemory_ )(int ,
120122 const py ::ssize_t * ,
@@ -233,15 +235,16 @@ class dpctl_capi
233235 SyclContext_Make_ (nullptr ), SyclEvent_GetEventRef_ (nullptr ),
234236 SyclEvent_Make_ (nullptr ), SyclQueue_GetQueueRef_ (nullptr ),
235237 SyclQueue_Make_ (nullptr ), Memory_GetUsmPointer_ (nullptr ),
236- Memory_GetContextRef_ (nullptr ), Memory_GetQueueRef_ (nullptr ),
237- Memory_GetNumBytes_ (nullptr ), Memory_Make_ (nullptr ),
238- SyclKernel_GetKernelRef_ (nullptr ), SyclKernel_Make_ (nullptr ),
239- SyclProgram_GetKernelBundleRef_ (nullptr ), SyclProgram_Make_ (nullptr ),
240- UsmNDArray_GetData_ (nullptr ), UsmNDArray_GetNDim_ (nullptr ),
241- UsmNDArray_GetShape_ (nullptr ), UsmNDArray_GetStrides_ (nullptr ),
242- UsmNDArray_GetTypenum_ (nullptr ), UsmNDArray_GetElementSize_ (nullptr ),
243- UsmNDArray_GetFlags_ (nullptr ), UsmNDArray_GetQueueRef_ (nullptr ),
244- UsmNDArray_GetOffset_ (nullptr ), UsmNDArray_SetWritableFlag_ (nullptr ),
238+ Memory_GetOpaquePointer_ (nullptr ), Memory_GetContextRef_ (nullptr ),
239+ Memory_GetQueueRef_ (nullptr ), Memory_GetNumBytes_ (nullptr ),
240+ Memory_Make_ (nullptr ), SyclKernel_GetKernelRef_ (nullptr ),
241+ SyclKernel_Make_ (nullptr ), SyclProgram_GetKernelBundleRef_ (nullptr ),
242+ SyclProgram_Make_ (nullptr ), UsmNDArray_GetData_ (nullptr ),
243+ UsmNDArray_GetNDim_ (nullptr ), UsmNDArray_GetShape_ (nullptr ),
244+ UsmNDArray_GetStrides_ (nullptr ), UsmNDArray_GetTypenum_ (nullptr ),
245+ UsmNDArray_GetElementSize_ (nullptr ), UsmNDArray_GetFlags_ (nullptr ),
246+ UsmNDArray_GetQueueRef_ (nullptr ), UsmNDArray_GetOffset_ (nullptr ),
247+ UsmNDArray_GetUSMData_ (nullptr ), UsmNDArray_SetWritableFlag_ (nullptr ),
245248 UsmNDArray_MakeSimpleFromMemory_ (nullptr ),
246249 UsmNDArray_MakeSimpleFromPtr_ (nullptr ),
247250 UsmNDArray_MakeFromPtr_ (nullptr ), USM_ARRAY_C_CONTIGUOUS_ (0 ),
@@ -299,6 +302,7 @@ class dpctl_capi
299302
300303 // dpctl.memory API
301304 this -> Memory_GetUsmPointer_ = Memory_GetUsmPointer ;
305+ this -> Memory_GetOpaquePointer_ = Memory_GetOpaquePointer ;
302306 this -> Memory_GetContextRef_ = Memory_GetContextRef ;
303307 this -> Memory_GetQueueRef_ = Memory_GetQueueRef ;
304308 this -> Memory_GetNumBytes_ = Memory_GetNumBytes ;
@@ -320,6 +324,7 @@ class dpctl_capi
320324 this -> UsmNDArray_GetFlags_ = UsmNDArray_GetFlags ;
321325 this -> UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef ;
322326 this -> UsmNDArray_GetOffset_ = UsmNDArray_GetOffset ;
327+ this -> UsmNDArray_GetUSMData_ = UsmNDArray_GetUSMData ;
323328 this -> UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag ;
324329 this -> UsmNDArray_MakeSimpleFromMemory_ =
325330 UsmNDArray_MakeSimpleFromMemory ;
@@ -779,6 +784,33 @@ class usm_memory : public py::object
779784 return api .Memory_GetNumBytes_ (mem_obj );
780785 }
781786
787+ bool is_managed_by_smart_ptr () const
788+ {
789+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
790+ Py_MemoryObject * mem_obj = reinterpret_cast < Py_MemoryObject * > (m_ptr );
791+ const void * opaque_ptr = api .Memory_GetOpaquePointer_ (mem_obj );
792+
793+ return bool (opaque_ptr );
794+ }
795+
796+ const std ::shared_ptr < void > & get_smart_ptr_owner () const
797+ {
798+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
799+ Py_MemoryObject * mem_obj = reinterpret_cast < Py_MemoryObject * > (m_ptr );
800+ void * opaque_ptr = api .Memory_GetOpaquePointer_ (mem_obj );
801+
802+ if (opaque_ptr ) {
803+ auto shptr_ptr =
804+ reinterpret_cast < std ::shared_ptr < void > * > (opaque_ptr );
805+ return * shptr_ptr ;
806+ }
807+ else {
808+ throw std ::runtime_error (
809+ "Memory object does not have smart pointer "
810+ "managing lifetime of USM allocation" );
811+ }
812+ }
813+
782814protected :
783815 static PyObject * as_usm_memory (PyObject * o )
784816 {
@@ -1065,6 +1097,71 @@ class usm_ndarray : public py::object
10651097 return static_cast < bool > (flags & api .USM_ARRAY_WRITABLE_ );
10661098 }
10671099
1100+ /*! @brief Get usm_data property of array */
1101+ py ::object get_usm_data () const
1102+ {
1103+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
1104+
1105+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
1106+ // UsmNDArray_GetUSMData_ gives a new reference
1107+ PyObject * usm_data = api .UsmNDArray_GetUSMData_ (raw_ar );
1108+
1109+ // pass reference ownership to py::object
1110+ return py ::reinterpret_steal < py ::object > (usm_data );
1111+ }
1112+
1113+ bool is_managed_by_smart_ptr () const
1114+ {
1115+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
1116+
1117+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
1118+ PyObject * usm_data = api .UsmNDArray_GetUSMData_ (raw_ar );
1119+
1120+ if (!PyObject_TypeCheck (usm_data , api .Py_MemoryType_ )) {
1121+ Py_DECREF (usm_data );
1122+ return false;
1123+ }
1124+
1125+ Py_MemoryObject * mem_obj =
1126+ reinterpret_cast < Py_MemoryObject * > (usm_data );
1127+ const void * opaque_ptr = api .Memory_GetOpaquePointer_ (mem_obj );
1128+
1129+ Py_DECREF (usm_data );
1130+ return bool (opaque_ptr );
1131+ }
1132+
1133+ const std ::shared_ptr < void > & get_smart_ptr_owner () const
1134+ {
1135+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
1136+
1137+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
1138+
1139+ PyObject * usm_data = api .UsmNDArray_GetUSMData_ (raw_ar );
1140+
1141+ if (!PyObject_TypeCheck (usm_data , api .Py_MemoryType_ )) {
1142+ Py_DECREF (usm_data );
1143+ throw std ::runtime_error (
1144+ "usm_ndarray object does not have Memory object "
1145+ "managing lifetime of USM allocation" );
1146+ }
1147+
1148+ Py_MemoryObject * mem_obj =
1149+ reinterpret_cast < Py_MemoryObject * > (usm_data );
1150+ void * opaque_ptr = api .Memory_GetOpaquePointer_ (mem_obj );
1151+ Py_DECREF (usm_data );
1152+
1153+ if (opaque_ptr ) {
1154+ auto shptr_ptr =
1155+ reinterpret_cast < std ::shared_ptr < void > * > (opaque_ptr );
1156+ return * shptr_ptr ;
1157+ }
1158+ else {
1159+ throw std ::runtime_error (
1160+ "Memory object underlying usm_ndarray does not have "
1161+ "smart pointer managing lifetime of USM allocation" );
1162+ }
1163+ }
1164+
10681165private :
10691166 PyUSMArrayObject * usm_array_ptr () const
10701167 {
@@ -1077,26 +1174,112 @@ class usm_ndarray : public py::object
10771174namespace utils
10781175{
10791176
1177+ namespace detail
1178+ {
1179+
1180+ struct ManagedMemory
1181+ {
1182+
1183+ static bool is_usm_managed_by_shared_ptr (const py ::object & h )
1184+ {
1185+ if (py ::isinstance < dpctl ::memory ::usm_memory > (h )) {
1186+ const auto & usm_memory_inst =
1187+ py ::cast < dpctl ::memory ::usm_memory > (h );
1188+ return usm_memory_inst .is_managed_by_smart_ptr ();
1189+ }
1190+ else if (py ::isinstance < dpctl ::tensor ::usm_ndarray > (h )) {
1191+ const auto & usm_array_inst =
1192+ py ::cast < dpctl ::tensor ::usm_ndarray > (h );
1193+ return usm_array_inst .is_managed_by_smart_ptr ();
1194+ }
1195+
1196+ return false;
1197+ }
1198+
1199+ static const std ::shared_ptr < void > & extract_shared_ptr (const py ::object & h )
1200+ {
1201+ if (py ::isinstance < dpctl ::memory ::usm_memory > (h )) {
1202+ const auto & usm_memory_inst =
1203+ py ::cast < dpctl ::memory ::usm_memory > (h );
1204+ return usm_memory_inst .get_smart_ptr_owner ();
1205+ }
1206+ else if (py ::isinstance < dpctl ::tensor ::usm_ndarray > (h )) {
1207+ const auto & usm_array_inst =
1208+ py ::cast < dpctl ::tensor ::usm_ndarray > (h );
1209+ return usm_array_inst .get_smart_ptr_owner ();
1210+ }
1211+
1212+ throw std ::runtime_error (
1213+ "Attempted extraction of shared_ptr on an unrecognized type" );
1214+ }
1215+ };
1216+
1217+ } // end of namespace detail
1218+
10801219template < std ::size_t num >
10811220sycl ::event keep_args_alive (sycl ::queue & q ,
10821221 const py ::object (& py_objs )[num ],
10831222 const std ::vector < sycl ::event > & depends = {})
10841223{
1085- sycl ::event host_task_ev = q .submit ([& ](sycl ::handler & cgh ) {
1086- cgh .depends_on (depends );
1087- std ::array < std ::shared_ptr < py ::handle > , num > shp_arr ;
1088- for (std ::size_t i = 0 ; i < num ; ++ i ) {
1089- shp_arr [i ] = std ::make_shared < py ::handle > (py_objs [i ]);
1090- shp_arr [i ]-> inc_ref ();
1224+ std ::size_t n_objects_held = 0 ;
1225+ std ::array < std ::shared_ptr < py ::handle > , num > shp_arr {};
1226+
1227+ std ::size_t n_usm_owners_held = 0 ;
1228+ std ::array < std ::shared_ptr < void > , num > shp_usm {};
1229+
1230+ for (std ::size_t i = 0 ; i < num ; ++ i ) {
1231+ const auto & py_obj_i = py_objs [i ];
1232+ if (detail ::ManagedMemory ::is_usm_managed_by_shared_ptr (py_obj_i )) {
1233+ const auto & shp =
1234+ detail ::ManagedMemory ::extract_shared_ptr (py_obj_i );
1235+ shp_usm [n_usm_owners_held ] = shp ;
1236+ ++ n_usm_owners_held ;
10911237 }
1092- cgh .host_task ([shp_arr = std ::move (shp_arr )]() {
1093- py ::gil_scoped_acquire acquire ;
1238+ else {
1239+ shp_arr [n_objects_held ] = std ::make_shared < py ::handle > (py_obj_i );
1240+ shp_arr [n_objects_held ]-> inc_ref ();
1241+ ++ n_objects_held ;
1242+ }
1243+ }
10941244
1095- for (std ::size_t i = 0 ; i < num ; ++ i ) {
1096- shp_arr [i ]-> dec_ref ();
1245+ bool use_depends = true;
1246+ sycl ::event host_task_ev ;
1247+
1248+ if (n_usm_owners_held > 0 ) {
1249+ host_task_ev = q .submit ([& ](sycl ::handler & cgh ) {
1250+ if (use_depends ) {
1251+ cgh .depends_on (depends );
1252+ use_depends = false;
10971253 }
1254+ else {
1255+ cgh .depends_on (host_task_ev );
1256+ }
1257+ cgh .host_task ([shp_usm = std ::move (shp_usm )]() {
1258+ // no body, but shared pointers are captured in
1259+ // the lambda, ensuring that USM allocation is
1260+ // kept alive
1261+ });
1262+ });
1263+ }
1264+
1265+ if (n_objects_held > 0 ) {
1266+ host_task_ev = q .submit ([& ](sycl ::handler & cgh ) {
1267+ if (use_depends ) {
1268+ cgh .depends_on (depends );
1269+ use_depends = false;
1270+ }
1271+ else {
1272+ cgh .depends_on (host_task_ev );
1273+ }
1274+ cgh .host_task ([n_objects_held , shp_arr = std ::move (shp_arr )]() {
1275+ py ::gil_scoped_acquire acquire ;
1276+
1277+ for (std ::size_t i = 0 ; i < n_objects_held ; ++ i ) {
1278+ shp_arr [i ]-> dec_ref ();
1279+ }
1280+ });
10981281 });
1099- });
1282+ }
11001283
11011284 return host_task_ev ;
11021285}
0 commit comments