@@ -79,6 +79,50 @@ class _CallableModel(BaseModel, abc.ABC):
7979 )
8080 meta : MetaData = Field (default_factory = MetaData )
8181
82+ @classmethod
83+ def _check_context_type (cls , context_type ):
84+ type_call_arg = _cached_signature (cls .__call__ ).parameters ["context" ].annotation
85+
86+ # If optional type, extract inner type
87+ if get_origin (type_call_arg ) is Optional or (get_origin (type_call_arg ) is Union and type (None ) in get_args (type_call_arg )):
88+ type_call_arg = [t for t in get_args (type_call_arg ) if t is not type (None )][0 ]
89+
90+ if (
91+ not isinstance (type_call_arg , TypeVar )
92+ and type_call_arg is not Signature .empty
93+ and (not isclass (type_call_arg ) or not issubclass (type_call_arg , context_type ))
94+ and (not isclass (context_type ) or not issubclass (context_type , type_call_arg ))
95+ ):
96+ err_msg_type_mismatch = f"The context_type { context_type } must match the type of the context accepted by __call__ { type_call_arg } "
97+ raise ValueError (err_msg_type_mismatch )
98+
99+ @classmethod
100+ def _check_result_type (cls , result_type ):
101+ type_call_return = _cached_signature (cls .__call__ ).return_annotation
102+
103+ # If union, check all types
104+ if get_origin (type_call_return ) is Union and get_args (type_call_return ):
105+ types_call_return = [t for t in get_args (type_call_return ) if t is not type (None )]
106+ else :
107+ types_call_return = [type_call_return ]
108+
109+ all_bad = True
110+ for type_call_return in types_call_return :
111+ if (
112+ not isinstance (type_call_return , TypeVar )
113+ and type_call_return is not Signature .empty
114+ and (not isclass (type_call_return ) or not issubclass (type_call_return , result_type ))
115+ and (not isclass (result_type ) or not issubclass (result_type , type_call_return ))
116+ ):
117+ # Don't invert logic so that we match context above
118+ pass
119+ else :
120+ all_bad = False
121+
122+ if all_bad :
123+ err_msg_type_mismatch = f"The result_type { result_type } must match the return type of __call__ { type_call_return } "
124+ raise ValueError (err_msg_type_mismatch )
125+
82126 @model_validator (mode = "after" )
83127 def _check_signature (self ):
84128 sig_call = _cached_signature (self .__class__ .__call__ )
@@ -98,50 +142,12 @@ def _check_signature(self):
98142 )
99143 raise ValueError (err_msg_type_mismatch )
100144
101- # If context_type or result_type are overridden, ensure they match the signature
145+ # If context_type or result_type are overridden or
146+ # come from generic type, ensure they match the signature
102147 if hasattr (self , "context_type" ):
103- type_call_arg = _cached_signature (self .__class__ .__call__ ).parameters ["context" ].annotation
104-
105- # If optional type, extract inner type
106- if get_origin (type_call_arg ) is Optional or (get_origin (type_call_arg ) is Union and type (None ) in get_args (type_call_arg )):
107- type_call_arg = [t for t in get_args (type_call_arg ) if t is not type (None )][0 ]
108-
109- if (
110- not isinstance (type_call_arg , TypeVar )
111- and type_call_arg is not Signature .empty
112- and (not isclass (type_call_arg ) or not issubclass (type_call_arg , self .context_type ))
113- and (not isclass (self .context_type ) or not issubclass (self .context_type , type_call_arg ))
114- ):
115- err_msg_type_mismatch = (
116- f"The context_type { self .context_type } must match the type of the context accepted by __call__ { type_call_arg } "
117- )
118- raise ValueError (err_msg_type_mismatch )
119-
148+ self ._check_context_type (self .context_type )
120149 if hasattr (self , "result_type" ):
121- type_call_return = _cached_signature (self .__class__ .__call__ ).return_annotation
122-
123- # If union, check all types
124- if get_origin (type_call_return ) is Union and get_args (type_call_return ):
125- types_call_return = [t for t in get_args (type_call_return ) if t is not type (None )]
126- else :
127- types_call_return = [type_call_return ]
128-
129- all_bad = True
130- for type_call_return in types_call_return :
131- if (
132- not isinstance (type_call_return , TypeVar )
133- and type_call_return is not Signature .empty
134- and (not isclass (type_call_return ) or not issubclass (type_call_return , self .result_type ))
135- and (not isclass (self .result_type ) or not issubclass (self .result_type , type_call_return ))
136- ):
137- # Don't invert logic so that we match context above
138- pass
139- else :
140- all_bad = False
141-
142- if all_bad :
143- err_msg_type_mismatch = f"The result_type { self .result_type } must match the return type of __call__ { type_call_return } "
144- raise ValueError (err_msg_type_mismatch )
150+ self ._check_result_type (self .result_type )
145151
146152 return self
147153
@@ -548,13 +554,17 @@ def context_type(self) -> Type[ContextType]:
548554 """
549555 typ = _cached_signature (self .__class__ .__call__ ).parameters ["context" ].annotation
550556 if typ is Signature .empty :
551- raise TypeError ("Must either define a type annotation for context on __call__ or implement 'context_type'" )
552-
553- self ._check_context_type (typ )
554- return typ
557+ if isinstance (self , CallableModelGenericType ) and hasattr (self , "_context_generic_type" ):
558+ typ = self ._context_generic_type
559+ else :
560+ raise TypeError ("Must either define a type annotation for context on __call__ or implement 'context_type'" )
561+ elif (
562+ isinstance (self , CallableModelGenericType ) and hasattr (self , "_context_generic_type" ) and not issubclass (typ , self ._context_generic_type )
563+ ):
564+ raise TypeError (
565+ f"Context type annotation { typ } on __call__ does not match context_type { self ._context_generic_type } defined by CallableModelGenericType"
566+ )
555567
556- @staticmethod
557- def _check_context_type (typ ):
558568 # If optional type, extract inner type
559569 if get_origin (typ ) is Optional or (get_origin (typ ) is Union and type (None ) in get_args (typ )):
560570 type_to_check = [t for t in get_args (typ ) if t is not type (None )][0 ]
@@ -565,6 +575,8 @@ def _check_context_type(typ):
565575 if not isclass (type_to_check ) or not issubclass (type_to_check , ContextBase ):
566576 raise TypeError (f"Context type declared in signature of __call__ must be a subclass of ContextBase. Received { type_to_check } ." )
567577
578+ return typ
579+
568580 @property
569581 def result_type (self ) -> Type [ResultType ]:
570582 """Return the result type for the model.
@@ -574,13 +586,29 @@ def result_type(self) -> Type[ResultType]:
574586 """
575587 typ = _cached_signature (self .__class__ .__call__ ).return_annotation
576588 if typ is Signature .empty :
577- raise TypeError ("Must either define a return type annotation on __call__ or implement 'result_type'" )
578-
579- self ._check_result_type (typ )
580- return typ
589+ if isinstance (self , CallableModelGenericType ) and hasattr (self , "_result_generic_type" ):
590+ typ = self ._result_generic_type
591+ else :
592+ raise TypeError ("Must either define a return type annotation on __call__ or implement 'result_type'" )
593+ elif isinstance (self , CallableModelGenericType ) and hasattr (self , "_result_generic_type" ):
594+ if get_origin (typ ) is Union and get_origin (self ._result_generic_type ) is Union :
595+ if set (get_args (typ )) != set (get_args (self ._result_generic_type )):
596+ raise TypeError (
597+ f"Return type annotation { typ } on __call__ does not match result_type { self ._result_generic_type } defined by CallableModelGenericType"
598+ )
599+ elif get_origin (typ ) is Union :
600+ raise NotImplementedError (
601+ "Return type annotation on __call__ is a Union, but result_type defined by CallableModelGenericType is not a Union. This case is not yet supported."
602+ )
603+ elif get_origin (self ._result_generic_type ) is Union :
604+ raise NotImplementedError (
605+ "Return type annotation on __call__ is not a Union, but result_type defined by CallableModelGenericType is a Union. This case is not yet supported."
606+ )
607+ elif not issubclass (typ , self ._result_generic_type ):
608+ raise TypeError (
609+ f"Return type annotation { typ } on __call__ does not match result_type { self ._result_generic_type } defined by CallableModelGenericType"
610+ )
581611
582- @staticmethod
583- def _check_result_type (typ ):
584612 # If union type, extract inner type
585613 if get_origin (typ ) is Union :
586614 raise TypeError (
@@ -590,6 +618,7 @@ def _check_result_type(typ):
590618 # Ensure subclass of ResultBase
591619 if not isclass (typ ) or not issubclass (typ , ResultBase ):
592620 raise TypeError (f"Return type declared in signature of __call__ must be a subclass of ResultBase (i.e. GenericResult). Received { typ } ." )
621+ return typ
593622
594623 @Flow .deps
595624 def __deps__ (
@@ -625,51 +654,48 @@ def result_type(self) -> Type[ResultType]:
625654 return self .model .result_type
626655
627656
628- class CallableModelGenericType (CallableModel , Generic [ContextType , ResultType ]):
657+ class CallableModelGeneric (CallableModel , Generic [ContextType , ResultType ]):
629658 """Special type of callable model that provides context and result via
630659 a generic type instead of annotations on __call__.
631660 """
632661
633- _context_type : ClassVar [Type [ContextType ]]
634- _result_type : ClassVar [Type [ResultType ]]
635-
636- @property
637- def context_type (self ) -> Type [ContextType ]:
638- return self ._context_type
639-
640- @property
641- def result_type (self ) -> Type [ResultType ]:
642- return self ._result_type
662+ _context_generic_type : ClassVar [Type [ContextType ]]
663+ _result_generic_type : ClassVar [Type [ResultType ]]
643664
644665 def __setstate__ (self , state ):
645- super ().__setstate__ (state )
646666 self ._determine_context_result ()
667+ super ().__setstate__ (state )
668+
669+ @classmethod
670+ def __pydantic_init_subclass__ (cls , ** kwargs ):
671+ super ().__pydantic_init_subclass__ (** kwargs )
672+ cls ._determine_context_result ()
647673
648674 @classmethod
649675 def _determine_context_result (cls ):
650676 # Extract the generic types from the class definition
651- if not hasattr (cls , "_context_type " ) or not hasattr (cls , "_result_type " ):
677+ if not hasattr (cls , "_context_generic_type " ) or not hasattr (cls , "_result_generic_type " ):
652678 new_context_type = None
653679 new_result_type = None
654680
655681 for base in cls .__mro__ :
656- if issubclass (base , CallableModelGenericType ):
682+ if issubclass (base , CallableModelGeneric ):
657683 # Found the generic base class, it should
658684 # have either generic parameters or context/result
659- if new_context_type is None and hasattr (base , "_context_type " ) and issubclass (base ._context_type , ContextBase ):
660- new_context_type = base ._context_type
685+ if new_context_type is None and hasattr (base , "_context_generic_type " ) and issubclass (base ._context_generic_type , ContextBase ):
686+ new_context_type = base ._context_generic_type
661687 if (
662688 new_result_type is None
663- and hasattr (base , "_result_type " )
689+ and hasattr (base , "_result_generic_type " )
664690 and (
665- issubclass (base ._result_type , ResultBase )
691+ issubclass (base ._result_generic_type , ResultBase )
666692 or (
667- get_origin (base ._result_type ) is Union
668- and all (isclass (t ) and issubclass (t , ResultBase ) for t in get_args (base ._result_type ))
693+ get_origin (base ._result_generic_type ) is Union
694+ and all (isclass (t ) and issubclass (t , ResultBase ) for t in get_args (base ._result_generic_type ))
669695 )
670696 )
671697 ):
672- new_result_type = base ._result_type
698+ new_result_type = base ._result_generic_type
673699 if base .__pydantic_generic_metadata__ ["args" ]:
674700 if len (base .__pydantic_generic_metadata__ ["args" ]) >= 2 :
675701 # Assume order is ContextType, ResultType
@@ -696,56 +722,12 @@ def _determine_context_result(cls):
696722 break
697723
698724 if new_context_type is not None :
699- # Validate that the model's context_type match
700- annotation_context_type = _cached_signature (cls .__call__ ).parameters ["context" ].annotation
701- if get_origin (annotation_context_type ) is Optional or (
702- get_origin (annotation_context_type ) is Union and type (None ) in get_args (annotation_context_type )
703- ):
704- annotation_context_type = [t for t in get_args (annotation_context_type ) if t is not type (None )][0 ]
705- if (
706- annotation_context_type is not Signature .empty
707- and not isinstance (annotation_context_type , TypeVar )
708- and not issubclass (annotation_context_type , new_context_type )
709- ):
710- raise TypeError (
711- f"Context type annotation { annotation_context_type } on __call__ does not match context_type { new_context_type } defined by CallableModelGenericType"
712- )
713- elif isclass (annotation_context_type ) and issubclass (annotation_context_type , new_context_type ):
714- new_context_type = annotation_context_type
715-
716725 # Set on class
717- cls ._context_type = new_context_type
726+ cls ._context_generic_type = new_context_type
718727
719728 if new_result_type is not None :
720- # Validate that the model's result_type match
721- annotation_result_type = _cached_signature (cls .__call__ ).return_annotation
722- if annotation_result_type is Signature .empty :
723- ...
724- elif isinstance (annotation_result_type , TypeVar ):
725- ...
726- elif get_origin (annotation_result_type ) is Union and get_origin (new_result_type ) is Union :
727- raise TypeError (
728- f"Return type annotation for __call__ cannot be union on a CallableModelGenericType with union `result_type`. Received { annotation_result_type } "
729- )
730- elif get_origin (annotation_result_type ) is Union :
731- if not any (issubclass (new_result_type , union_type ) for union_type in get_args (annotation_result_type )):
732- raise TypeError (
733- f"Return type annotation { annotation_result_type } on __call__ does not match result_type { new_result_type } defined by CallableModelGenericType"
734- )
735- elif get_origin (new_result_type ) is Union :
736- if not any (issubclass (annotation_result_type , union_type ) for union_type in get_args (new_result_type )):
737- raise TypeError (
738- f"Return type annotation { annotation_result_type } on __call__ does not match result_type { new_result_type } defined by CallableModelGenericType"
739- )
740- elif not issubclass (annotation_result_type , new_result_type ):
741- raise TypeError (
742- f"Return type annotation { annotation_result_type } on __call__ does not match result_type { new_result_type } defined by CallableModelGenericType"
743- )
744- elif isclass (annotation_result_type ) and issubclass (annotation_result_type , new_result_type ):
745- new_result_type = annotation_result_type
746-
747729 # Set on class
748- cls ._result_type = new_result_type
730+ cls ._result_generic_type = new_result_type
749731
750732 @model_validator (mode = "wrap" )
751733 def _validate_callable_model_generic_type (cls , m , handler , info ):
@@ -756,7 +738,8 @@ def _validate_callable_model_generic_type(cls, m, handler, info):
756738
757739 if isinstance (m , dict ):
758740 m = handler (m )
759- cls ._determine_context_result ()
741+ elif isinstance (m , cls ):
742+ m = handler (m )
760743
761744 # Raise ValueError (not TypeError) as per https://docs.pydantic.dev/latest/errors/errors/
762745 if not isinstance (m , CallableModel ):
@@ -768,3 +751,6 @@ def _validate_callable_model_generic_type(cls, m, handler, info):
768751 TypeAdapter (Type [subtypes [1 ]]).validate_python (m .result_type )
769752
770753 return m
754+
755+
756+ CallableModelGenericType = CallableModelGeneric
0 commit comments