Skip to content

Commit 1cc054f

Browse files
authored
fix: endpoint class fix (deepgram#551)
1 parent cab41fa commit 1cc054f

File tree

2 files changed

+301
-8
lines changed

2 files changed

+301
-8
lines changed

deepgram/clients/agent/v1/websocket/options.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,12 @@ class Endpoint(BaseResponse):
9393

9494
method: Optional[str] = field(default="POST")
9595
url: str = field(default="")
96-
headers: Optional[List[Header]] = field(
96+
headers: Optional[Dict[str, str]] = field(
9797
default=None, metadata=dataclass_config(exclude=lambda f: f is None)
9898
)
9999

100100
def __getitem__(self, key):
101101
_dict = self.to_dict()
102-
if "headers" in _dict:
103-
_dict["headers"] = [
104-
Header.from_dict(headers) for headers in _dict["headers"]
105-
]
106102
return _dict[key]
107103

108104

@@ -116,7 +112,7 @@ class Function(BaseResponse):
116112
description: str
117113
url: str
118114
method: str
119-
headers: Optional[List[Header]] = field(
115+
headers: Optional[Dict[str, str]] = field(
120116
default=None, metadata=dataclass_config(exclude=lambda f: f is None)
121117
)
122118
parameters: Optional[Parameters] = field(
@@ -130,8 +126,6 @@ def __getitem__(self, key):
130126
_dict = self.to_dict()
131127
if "parameters" in _dict and isinstance(_dict["parameters"], dict):
132128
_dict["parameters"] = Parameters.from_dict(_dict["parameters"])
133-
if "headers" in _dict and isinstance(_dict["headers"], list):
134-
_dict["headers"] = [Header.from_dict(header) for header in _dict["headers"]]
135129
if "endpoint" in _dict and isinstance(_dict["endpoint"], dict):
136130
_dict["endpoint"] = Endpoint.from_dict(_dict["endpoint"])
137131
return _dict[key]
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
# Copyright 2024 Deepgram SDK contributors. All Rights Reserved.
2+
# Use of this source code is governed by a MIT license that can be found in the LICENSE file.
3+
# SPDX-License-Identifier: MIT
4+
5+
import pytest
6+
import json
7+
from unittest.mock import patch, MagicMock
8+
9+
from deepgram import (
10+
DeepgramClient,
11+
SettingsOptions,
12+
Endpoint,
13+
Function,
14+
Header,
15+
)
16+
17+
18+
class TestEndpointHeaders:
19+
"""Unit tests for Endpoint.headers functionality using dictionary format"""
20+
21+
def test_endpoint_headers_dict_format(self):
22+
"""Test that Endpoint accepts headers as a dictionary"""
23+
headers = {"authorization": "Bearer token", "content-type": "application/json"}
24+
endpoint = Endpoint(
25+
url="https://api.example.com/v1/test",
26+
headers=headers
27+
)
28+
29+
assert endpoint.headers == headers
30+
assert endpoint.headers["authorization"] == "Bearer token"
31+
assert endpoint.headers["content-type"] == "application/json"
32+
33+
def test_endpoint_headers_serialization(self):
34+
"""Test that Endpoint with dict headers serializes correctly to JSON"""
35+
headers = {"authorization": "Bearer token"}
36+
endpoint = Endpoint(
37+
url="https://api.example.com/v1/test",
38+
headers=headers
39+
)
40+
41+
# Test direct JSON serialization
42+
json_data = endpoint.to_json()
43+
parsed = json.loads(json_data)
44+
45+
assert parsed["headers"] == headers
46+
assert parsed["headers"]["authorization"] == "Bearer token"
47+
assert parsed["url"] == "https://api.example.com/v1/test"
48+
assert parsed["method"] == "POST" # default value
49+
50+
def test_endpoint_headers_none(self):
51+
"""Test that Endpoint works correctly with None headers"""
52+
endpoint = Endpoint(url="https://api.example.com/v1/test")
53+
54+
assert endpoint.headers is None
55+
56+
# Test serialization with None headers
57+
json_data = endpoint.to_json()
58+
parsed = json.loads(json_data)
59+
60+
assert "headers" not in parsed # Should be excluded when None
61+
62+
def test_endpoint_headers_empty_dict(self):
63+
"""Test that Endpoint works correctly with empty dict headers"""
64+
endpoint = Endpoint(
65+
url="https://api.example.com/v1/test",
66+
headers={}
67+
)
68+
69+
assert endpoint.headers == {}
70+
71+
# Test serialization with empty headers
72+
json_data = endpoint.to_json()
73+
parsed = json.loads(json_data)
74+
75+
assert parsed["headers"] == {}
76+
77+
def test_endpoint_from_dict_with_headers(self):
78+
"""Test that Endpoint.from_dict works correctly with dict headers"""
79+
data = {
80+
"url": "https://api.example.com/v1/test",
81+
"method": "POST",
82+
"headers": {"authorization": "Bearer token", "x-custom": "value"}
83+
}
84+
85+
endpoint = Endpoint.from_dict(data)
86+
87+
assert endpoint.url == "https://api.example.com/v1/test"
88+
assert endpoint.method == "POST"
89+
assert endpoint.headers == {"authorization": "Bearer token", "x-custom": "value"}
90+
91+
def test_endpoint_aws_polly_use_case(self):
92+
"""Test the specific AWS Polly use case from the bug report"""
93+
endpoint = Endpoint(
94+
url="https://polly.ap-northeast-1.amazonaws.com/v1/speech",
95+
headers={"authorization": "Bearer token"}
96+
)
97+
98+
# Test that it matches the API specification format
99+
json_data = endpoint.to_json()
100+
parsed = json.loads(json_data)
101+
102+
expected_format = {
103+
"url": "https://polly.ap-northeast-1.amazonaws.com/v1/speech",
104+
"method": "POST",
105+
"headers": {
106+
"authorization": "Bearer token"
107+
}
108+
}
109+
110+
assert parsed == expected_format
111+
112+
113+
class TestFunctionHeaders:
114+
"""Unit tests for Function.headers functionality using dictionary format"""
115+
116+
def test_function_headers_dict_format(self):
117+
"""Test that Function accepts headers as a dictionary"""
118+
headers = {"authorization": "Bearer token", "content-type": "application/json"}
119+
function = Function(
120+
name="test_function",
121+
description="Test function",
122+
url="https://api.example.com/v1/function",
123+
method="POST",
124+
headers=headers
125+
)
126+
127+
assert function.headers == headers
128+
assert function.headers["authorization"] == "Bearer token"
129+
130+
def test_function_headers_serialization(self):
131+
"""Test that Function with dict headers serializes correctly to JSON"""
132+
headers = {"authorization": "Bearer token"}
133+
function = Function(
134+
name="test_function",
135+
description="Test function",
136+
url="https://api.example.com/v1/function",
137+
method="POST",
138+
headers=headers
139+
)
140+
141+
json_data = function.to_json()
142+
parsed = json.loads(json_data)
143+
144+
assert parsed["headers"] == headers
145+
assert parsed["name"] == "test_function"
146+
147+
def test_function_from_dict_with_headers(self):
148+
"""Test that Function.from_dict works correctly with dict headers"""
149+
data = {
150+
"name": "test_function",
151+
"description": "Test function",
152+
"url": "https://api.example.com/v1/function",
153+
"method": "POST",
154+
"headers": {"authorization": "Bearer token", "x-custom": "value"}
155+
}
156+
157+
function = Function.from_dict(data)
158+
159+
assert function.name == "test_function"
160+
assert function.headers == {"authorization": "Bearer token", "x-custom": "value"}
161+
162+
163+
class TestSettingsOptionsWithEndpoint:
164+
"""Test SettingsOptions with Endpoint containing headers"""
165+
166+
def test_settings_options_with_endpoint_headers(self):
167+
"""Test full SettingsOptions with speak endpoint headers"""
168+
options = SettingsOptions()
169+
170+
# Configure AWS Polly example from bug report
171+
options.agent.speak.provider.type = "aws_polly"
172+
options.agent.speak.provider.language_code = "en-US"
173+
options.agent.speak.provider.voice = "Matthew"
174+
options.agent.speak.provider.engine = "standard"
175+
options.agent.speak.endpoint = Endpoint(
176+
url="https://polly.ap-northeast-1.amazonaws.com/v1/speech",
177+
headers={"authorization": "Bearer token"}
178+
)
179+
180+
# Test serialization
181+
json_data = options.to_json()
182+
parsed = json.loads(json_data)
183+
184+
# Verify the endpoint headers are in the correct format
185+
speak_endpoint = parsed["agent"]["speak"]["endpoint"]
186+
assert speak_endpoint["url"] == "https://polly.ap-northeast-1.amazonaws.com/v1/speech"
187+
assert speak_endpoint["headers"] == {"authorization": "Bearer token"}
188+
189+
def test_settings_options_multiple_header_values(self):
190+
"""Test endpoint with multiple header values"""
191+
options = SettingsOptions()
192+
193+
headers = {
194+
"authorization": "Bearer token",
195+
"content-type": "application/json",
196+
"x-custom-header": "custom-value"
197+
}
198+
199+
options.agent.speak.endpoint = Endpoint(
200+
url="https://api.example.com/v1/speech",
201+
headers=headers
202+
)
203+
204+
json_data = options.to_json()
205+
parsed = json.loads(json_data)
206+
207+
endpoint_headers = parsed["agent"]["speak"]["endpoint"]["headers"]
208+
assert endpoint_headers == headers
209+
assert len(endpoint_headers) == 3
210+
211+
def test_settings_options_think_endpoint_headers(self):
212+
"""Test think endpoint with headers"""
213+
options = SettingsOptions()
214+
215+
options.agent.think.endpoint = Endpoint(
216+
url="https://api.openai.com/v1/chat/completions",
217+
headers={"authorization": "Bearer sk-..."}
218+
)
219+
220+
json_data = options.to_json()
221+
parsed = json.loads(json_data)
222+
223+
think_endpoint = parsed["agent"]["think"]["endpoint"]
224+
assert think_endpoint["headers"] == {"authorization": "Bearer sk-..."}
225+
226+
227+
class TestBackwardCompatibility:
228+
"""Test backward compatibility with Header class"""
229+
230+
def test_header_class_still_exists(self):
231+
"""Test that Header class still exists for backward compatibility"""
232+
header = Header(key="authorization", value="Bearer token")
233+
assert header.key == "authorization"
234+
assert header.value == "Bearer token"
235+
236+
def test_header_serialization(self):
237+
"""Test that Header still serializes correctly"""
238+
header = Header(key="authorization", value="Bearer token")
239+
json_data = header.to_json()
240+
parsed = json.loads(json_data)
241+
242+
assert parsed["key"] == "authorization"
243+
assert parsed["value"] == "Bearer token"
244+
245+
246+
class TestErrorHandling:
247+
"""Test error handling and edge cases"""
248+
249+
def test_endpoint_headers_with_non_string_values(self):
250+
"""Test behavior with non-string header values"""
251+
# Test that non-string values are handled appropriately
252+
endpoint = Endpoint(
253+
url="https://api.example.com/v1/test",
254+
headers={"authorization": "Bearer token", "timeout": "30"} # Should be strings
255+
)
256+
257+
assert endpoint.headers["timeout"] == "30"
258+
259+
# Test serialization
260+
json_data = endpoint.to_json()
261+
parsed = json.loads(json_data)
262+
assert parsed["headers"]["timeout"] == "30"
263+
264+
265+
# Integration test with properly mocked WebSocket client
266+
class TestIntegrationWithAgentClient:
267+
"""Integration test with the agent websocket client"""
268+
269+
@patch('websockets.sync.client.connect')
270+
def test_endpoint_headers_integration(self, mock_connect):
271+
"""Test that headers work correctly in integration with agent client"""
272+
# Mock the websocket connection to avoid real connections
273+
mock_websocket = MagicMock()
274+
mock_websocket.send.return_value = None
275+
mock_websocket.recv.return_value = '{"type": "Welcome"}'
276+
mock_connect.return_value = mock_websocket
277+
278+
client = DeepgramClient("fake-key")
279+
connection = client.agent.websocket.v("1")
280+
281+
options = SettingsOptions()
282+
options.agent.speak.endpoint = Endpoint(
283+
url="https://polly.ap-northeast-1.amazonaws.com/v1/speech",
284+
headers={"authorization": "Bearer token"}
285+
)
286+
287+
# Test that the options serialize correctly without making real connections
288+
options_json = options.to_json()
289+
parsed = json.loads(options_json)
290+
291+
# Verify the headers are in the correct format in the serialized options
292+
speak_endpoint = parsed["agent"]["speak"]["endpoint"]
293+
assert speak_endpoint["headers"] == {"authorization": "Bearer token"}
294+
assert speak_endpoint["url"] == "https://polly.ap-northeast-1.amazonaws.com/v1/speech"
295+
296+
# Test that the Endpoint can be reconstructed from the JSON
297+
reconstructed_endpoint = Endpoint.from_dict(speak_endpoint)
298+
assert reconstructed_endpoint.headers == {"authorization": "Bearer token"}
299+
assert reconstructed_endpoint.url == "https://polly.ap-northeast-1.amazonaws.com/v1/speech"

0 commit comments

Comments
 (0)