Skip to content

Commit 6c4f569

Browse files
committed
Update sdk
1 parent 10f94be commit 6c4f569

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

aimon/decorators/detect.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def __init__(self, values_returned, api_key=None, config=None, async_mode=False,
122122
api_key = os.getenv('AIMON_API_KEY') if not api_key else api_key
123123
if api_key is None:
124124
raise ValueError("API key is None")
125-
self.client = Client(auth_header="Bearer {}".format(api_key))
125+
self.client = Client(auth_header="Bearer {}".format(api_key), base_url='https://am-sdk-backend-staging-ser-6009-0c0ad782-m9xwngeb.onporter.run/')
126126
self.config = config if config else self.DEFAULT_CONFIG
127127
self.values_returned = values_returned
128128
if self.values_returned is None or len(self.values_returned) == 0:
@@ -137,7 +137,10 @@ def __init__(self, values_returned, api_key=None, config=None, async_mode=False,
137137
if application_name is None:
138138
raise ValueError("Application name must be provided if publish is True")
139139
if model_name is None:
140-
raise ValueError("Model name must be provided if publish is True")
140+
raise ValueError("Model name must be provided if publish is True")
141+
142+
self.application_name = application_name
143+
self.model_name = model_name
141144

142145
def __call__(self, func):
143146
@wraps(func)
@@ -164,13 +167,34 @@ def wrapper(*args, **kwargs):
164167
aimon_payload['user_query'] = result_dict['user_query']
165168
if 'instructions' in result_dict:
166169
aimon_payload['instructions'] = result_dict['instructions']
170+
167171
aimon_payload['config'] = self.config
168172
aimon_payload['publish'] = self.publish
169173
aimon_payload['async_mode'] = self.async_mode
170174

175+
# Include application_name and model_name if publishing
176+
if self.publish:
177+
aimon_payload['application_name'] = self.application_name
178+
aimon_payload['model_name'] = self.model_name
179+
171180
data_to_send = [aimon_payload]
172181

173-
detect_response = self.client.inference.detect(body=data_to_send)[0]
174-
return result + (DetectResult(200 if detect_response is not None else 500, detect_response),)
182+
try:
183+
detect_response = self.client.inference.detect(body=data_to_send)
184+
# Check if the response is a list
185+
if isinstance(detect_response, list) and len(detect_response) > 0:
186+
detect_result = detect_response[0]
187+
elif isinstance(detect_response, dict):
188+
detect_result = detect_response # Single dict response
189+
else:
190+
raise ValueError("Unexpected response format from detect API: {}".format(detect_response))
191+
except Exception as e:
192+
# Log the error and raise it
193+
print(f"Error during detection: {e}")
194+
raise
195+
196+
# Return the original result along with the DetectResult
197+
return result + (DetectResult(200 if detect_result else 500, detect_result),)
198+
175199

176200
return wrapper

aimon/types/inference_detect_params.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,9 @@ class Body(TypedDict, total=False):
8080

8181
async_mode: bool
8282
"""If True, the detect() function will return immediately with a DetectResult object. Default is False."""
83+
84+
application_name: str
85+
"""Application name"""
86+
87+
model_name: str
88+
"""Model name"""

0 commit comments

Comments
 (0)