44
55from __future__ import annotations
66
7- import weakref
7+ from cuda.core.experimental._utils.cuda_utils cimport (
8+ _check_driver_error as raise_if_driver_error,
9+ check_or_create_options,
10+ )
11+
812from dataclasses import dataclass
913from typing import TYPE_CHECKING, Optional
1014
1115from cuda.core.experimental._context import Context
1216from cuda.core.experimental._utils.cuda_utils import (
1317 CUDAError,
14- check_or_create_options ,
1518 driver,
1619 handle_return,
1720)
18- from cuda .core .experimental ._utils .cuda_utils import (
19- _check_driver_error as raise_if_driver_error ,
20- )
2121
2222if TYPE_CHECKING:
2323 import cuda.bindings
2424 from cuda.core.experimental._device import Device
2525
2626
2727@dataclass
28- class EventOptions :
28+ cdef class EventOptions:
2929 """ Customizable :obj:`~_event.Event` options.
3030
3131 Attributes
@@ -49,7 +49,7 @@ class EventOptions:
4949 support_ipc: Optional[bool ] = False
5050
5151
52- class Event :
52+ cdef class Event:
5353 """ Represent a record at a specific point of execution within a CUDA stream.
5454
5555 Applications can asynchronously record events at any point in
@@ -77,49 +77,46 @@ class Event:
7777 and they should instead be created through a :obj:`~_stream.Stream` object.
7878
7979 """
80-
81- class _MembersNeededForFinalize :
82- __slots__ = ("handle" ,)
83-
84- def __init__ (self , event_obj , handle ):
85- self .handle = handle
86- weakref .finalize (event_obj , self .close )
87-
88- def close (self ):
89- if self .handle is not None :
90- handle_return (driver .cuEventDestroy (self .handle ))
91- self .handle = None
92-
93- def __new__ (self , * args , ** kwargs ):
80+ cdef:
81+ object _handle
82+ bint _timing_disabled
83+ bint _busy_waited
84+ int _device_id
85+ object _ctx_handle
86+
87+ def __init__ (self , *args , **kwargs ):
9488 raise RuntimeError (" Event objects cannot be instantiated directly. Please use Stream APIs (record)." )
9589
96- __slots__ = ("__weakref__" , "_mnff" , "_timing_disabled" , "_busy_waited" , "_device_id" , "_ctx_handle" )
97-
9890 @classmethod
99- def _init (cls , device_id : int , ctx_handle : Context , options : Optional [EventOptions ] = None ):
100- self = super ().__new__ (cls )
101- self ._mnff = Event ._MembersNeededForFinalize (self , None )
102-
103- options = check_or_create_options (EventOptions , options , "Event options" )
91+ def _init (cls , device_id: int , ctx_handle: Context , options = None ):
92+ cdef Event self = Event.__new__ (Event)
93+ cdef EventOptions opts = check_or_create_options(EventOptions, options, " Event options" )
10494 flags = 0x0
10595 self ._timing_disabled = False
10696 self ._busy_waited = False
107- if not options .enable_timing :
97+ if not opts .enable_timing:
10898 flags |= driver.CUevent_flags.CU_EVENT_DISABLE_TIMING
10999 self ._timing_disabled = True
110- if options .busy_waited_sync :
100+ if opts .busy_waited_sync:
111101 flags |= driver.CUevent_flags.CU_EVENT_BLOCKING_SYNC
112102 self ._busy_waited = True
113- if options .support_ipc :
103+ if opts .support_ipc:
114104 raise NotImplementedError (" WIP: https://github.com/NVIDIA/cuda-python/issues/103" )
115- self ._mnff .handle = handle_return (driver .cuEventCreate (flags ))
105+ err, self ._handle = driver.cuEventCreate(flags)
106+ raise_if_driver_error(err)
116107 self ._device_id = device_id
117108 self ._ctx_handle = ctx_handle
118109 return self
119110
120- def close (self ):
111+ cpdef close(self ):
121112 """ Destroy the event."""
122- self ._mnff .close ()
113+ if self ._handle is not None :
114+ err, = driver.cuEventDestroy(self ._handle)
115+ self ._handle = None
116+ raise_if_driver_error(err)
117+
118+ def __del__ (self ):
119+ self .close()
123120
124121 def __isub__ (self , other ):
125122 return NotImplemented
@@ -129,7 +126,7 @@ def __rsub__(self, other):
129126
130127 def __sub__ (self , other ):
131128 # return self - other (in milliseconds)
132- err , timing = driver .cuEventElapsedTime (other .handle , self .handle )
129+ err, timing = driver.cuEventElapsedTime(other.handle, self ._handle )
133130 try :
134131 raise_if_driver_error(err)
135132 return timing
@@ -180,12 +177,12 @@ def sync(self):
180177 has been completed.
181178
182179 """
183- handle_return (driver .cuEventSynchronize (self ._mnff . handle ))
180+ handle_return(driver.cuEventSynchronize(self ._handle ))
184181
185182 @property
186183 def is_done (self ) -> bool:
187184 """Return True if all captured works have been completed , otherwise False."""
188- ( result ,) = driver .cuEventQuery (self ._mnff . handle )
185+ result , = driver.cuEventQuery(self._handle )
189186 if result == driver.CUresult.CUDA_SUCCESS:
190187 return True
191188 if result == driver.CUresult.CUDA_ERROR_NOT_READY:
@@ -201,7 +198,7 @@ def handle(self) -> cuda.bindings.driver.CUevent:
201198 This handle is a Python object. To get the memory address of the underlying C
202199 handle , call ``int(Event.handle )``.
203200 """
204- return self ._mnff . handle
201+ return self._handle
205202
206203 @property
207204 def device(self ) -> Device:
0 commit comments