 import openai
 import tiktoken
 from langfuse.model import InitialGeneration, Usage
+from openai import OpenAI
 from tenacity import *
 
 from pentestgpt.utils.llm_api import LLMAPI
@@ -46,6 +47,8 @@ def __eq__(self, other):
 class ChatGPTAPI(LLMAPI):
     def __init__(self, config_class, use_langfuse_logging=False):
         self.name = str(config_class.model)
+        api_key = os.getenv("OPENAI_API_KEY", None)
+        self.client = OpenAI(api_key=api_key, base_url=config_class.api_base)
 
         if use_langfuse_logging:
             # use langfuse.openai to shadow the default openai library
@@ -58,9 +61,7 @@ def __init__(self, config_class, use_langfuse_logging=False):
             from langfuse import Langfuse
 
             self.langfuse = Langfuse()
-
-        openai.api_key = os.getenv("OPENAI_API_KEY", None)
-        openai.api_base = config_class.api_base
+
         self.model = config_class.model
         self.log_dir = config_class.log_dir
         self.history_length = 5  # maintain 5 messages in the history. (5 chat memory)
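For reference, the configuration change above in one place: the legacy `openai<1.0` module set credentials as package-level globals, while `openai>=1.0` scopes them to a client instance. A minimal sketch, with a placeholder base URL standing in for `config_class.api_base`:

```python
import os

from openai import OpenAI

# openai<1.0 (removed above): package-level globals
#   openai.api_key = os.getenv("OPENAI_API_KEY", None)
#   openai.api_base = config_class.api_base

# openai>=1.0 (added above): per-client configuration
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY", None),  # also read from the env by default
    base_url="https://api.openai.com/v1",  # placeholder for config_class.api_base
)
```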
@@ -69,7 +70,9 @@ def __init__(self, config_class, use_langfuse_logging=False):
 
         logger.add(sink=os.path.join(self.log_dir, "chatgpt.log"), level="WARNING")
 
-    def _chat_completion(self, history: List, model=None, temperature=0.5) -> str:
+    def _chat_completion(
+        self, history: List, model=None, temperature=0.5, image_url: str = None
+    ) -> str:
         generationStartTime = datetime.now()
         # use model if provided, otherwise use self.model; if self.model is None, use gpt-4-1106-preview
         if model is None:
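The new `image_url` parameter is not consumed in the hunks shown here; presumably it targets vision-capable models, which in the 1.x API take an image as a content part on a user message. A hedged sketch of that message shape (the helper name is illustrative, not part of the commit):

```python
def build_vision_message(text: str, image_url: str = None) -> dict:
    # Illustrative helper: wraps text, plus an optional image URL, into the
    # content-part format that vision-capable chat models accept.
    if image_url is None:
        return {"role": "user", "content": text}
    return {
        "role": "user",
        "content": [
            {"type": "text", "text": text},
            {"type": "image_url", "image_url": {"url": image_url}},
        ],
    }
```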
@@ -78,12 +81,12 @@ def _chat_completion(self, history: List, model=None, temperature=0.5) -> str:
         else:
             model = self.model
         try:
-            response = openai.ChatCompletion.create(
+            response = self.client.chat.completions.create(
                 model=model,
                 messages=history,
                 temperature=temperature,
             )
-        except openai.error.APIConnectionError as e:  # give one more try
+        except openai._exceptions.APIConnectionError as e:  # give one more try
             logger.warning(
                 "API Connection Error. Waiting for {} seconds".format(
                     self.error_wait_time
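Two notes on this hunk. First, the call moves from the removed class method `openai.ChatCompletion.create` to the instance method `self.client.chat.completions.create`, with identical keyword arguments. Second, `openai._exceptions.APIConnectionError` resolves, but the 1.x SDK re-exports its exception classes at the package root, so the public spelling is simply `openai.APIConnectionError`. A minimal sketch of the migrated call with the public name (model string illustrative):

```python
import openai
from openai import OpenAI

client = OpenAI()  # picks up OPENAI_API_KEY from the environment

try:
    response = client.chat.completions.create(
        model="gpt-4-1106-preview",
        messages=[{"role": "user", "content": "ping"}],
        temperature=0.5,
    )
except openai.APIConnectionError:
    print("connection error; the handler above retries once after a pause")
```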
@@ -96,7 +99,7 @@ def _chat_completion(self, history: List, model=None, temperature=0.5) -> str:
                 messages=history,
                 temperature=temperature,
             )
-        except openai.error.RateLimitError as e:  # give one more try
+        except openai._exceptions.RateLimitError as e:  # give one more try
             logger.warning("Rate limit reached. Waiting for 5 seconds")
             logger.error("Rate Limit Error: ", e)
             time.sleep(5)
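Since the module already does `from tenacity import *`, the manual sleep-and-retry in this handler could also be written declaratively. A sketch of an equivalent policy (one retry after 5 seconds), not what the commit does:

```python
import openai
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

@retry(
    retry=retry_if_exception_type(openai.RateLimitError),
    wait=wait_fixed(5),
    stop=stop_after_attempt(2),  # the initial call plus one retry
)
def create_with_retry(client, **kwargs):
    # Delegates to the 1.x client; tenacity re-raises after the last attempt.
    return client.chat.completions.create(**kwargs)
```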
@@ -105,7 +108,7 @@ def _chat_completion(self, history: List, model=None, temperature=0.5) -> str:
                 messages=history,
                 temperature=temperature,
             )
-        except openai.error.InvalidRequestError as e:  # token limit reached
+        except openai._exceptions.BadRequestError as e:  # token limit reached
             logger.warning("Token size limit reached. The recent message is compressed")
             logger.error("Token size error; will retry with compressed message ", e)
             # compress the message in two ways.
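In the 1.x SDK, `InvalidRequestError` was renamed `BadRequestError`, which is what a context-length overflow raises; catching `RateLimitError` here would shadow the earlier handler and never fire for oversized prompts. The compression branch needs the prompt's token footprint to decide how much to trim; `tiktoken`, imported at the top of the file, can measure it. A rough sketch (the encoding name is an assumption, and it ignores per-message framing overhead):

```python
import tiktoken

def count_history_tokens(history: list, encoding_name: str = "cl100k_base") -> int:
    # Rough token count for a chat history of {"role": ..., "content": ...} dicts.
    encoding = tiktoken.get_encoding(encoding_name)
    return sum(len(encoding.encode(message["content"])) for message in history)
```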
@@ -151,14 +154,14 @@ def _chat_completion(self, history: List, model=None, temperature=0.5) -> str:
                 model=self.model,
                 modelParameters={"temperature": str(temperature)},
                 prompt=history,
-                completion=response["choices"][0]["message"]["content"],
+                completion=response.choices[0].message.content,
                 usage=Usage(
-                    promptTokens=response["usage"]["prompt_tokens"],
-                    completionTokens=response["usage"]["completion_tokens"],
+                    promptTokens=response.usage.prompt_tokens,
+                    completionTokens=response.usage.completion_tokens,
                ),
            )
        )
-        return response["choices"][0]["message"]["content"]
+        return response.choices[0].message.content
 
 
 if __name__ == "__main__":
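The access-pattern change in this last hunk follows from the 1.x SDK returning typed objects instead of dicts, so the same fields move from item lookup to attribute lookup. A small accessor showing both spellings side by side (function name illustrative):

```python
def extract_completion(response) -> tuple:
    # openai<1.0 used dict access: response["choices"][0]["message"]["content"]
    # openai>=1.0 returns typed objects, so the same fields are attributes:
    text = response.choices[0].message.content
    prompt_tokens = response.usage.prompt_tokens
    completion_tokens = response.usage.completion_tokens
    return text, prompt_tokens, completion_tokens
```

A plain dict is still available via `response.model_dump()` when one is needed, e.g. for logging.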