Skip to content

Commit a2d66fe

Browse files
authored
Merge pull request #540 from deepgram/fix/empty
fix: handle empty objects for providers
2 parents 03bc7e3 + 9880f83 commit a2d66fe

File tree

3 files changed

+77
-11
lines changed

3 files changed

+77
-11
lines changed

deepgram/clients/agent/v1/websocket/options.py

Lines changed: 75 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,19 @@ def __getitem__(self, key):
6868
_dict["properties"] = _dict["properties"].copy()
6969
return _dict[key]
7070

71+
7172
class Provider(dict):
7273
"""
7374
Generic attribute class for provider objects.
7475
"""
76+
7577
def __getattr__(self, name):
7678
try:
7779
return self[name]
7880
except KeyError:
7981
# pylint: disable=raise-missing-from
8082
raise AttributeError(name)
83+
8184
def __setattr__(self, name, value):
8285
self[name] = value
8386

@@ -140,7 +143,16 @@ class Think(BaseResponse):
140143
This class defines any configuration settings for the Think model.
141144
"""
142145

143-
provider: Provider = field(default_factory=Provider)
146+
provider: Provider = field(
147+
default_factory=Provider,
148+
metadata=dataclass_config(
149+
exclude=lambda f: (
150+
f is None
151+
or (isinstance(f, dict) and not f)
152+
or (isinstance(f, Provider) and not f)
153+
)
154+
),
155+
)
144156
functions: Optional[List[Function]] = field(
145157
default=None, metadata=dataclass_config(exclude=lambda f: f is None)
146158
)
@@ -155,7 +167,11 @@ class Think(BaseResponse):
155167
)
156168

157169
def __post_init__(self):
158-
if not isinstance(self.provider, Provider):
170+
if (
171+
not isinstance(self.provider, Provider)
172+
and self.provider is not None
173+
and not (isinstance(self.provider, dict) and not self.provider)
174+
):
159175
self.provider = Provider(self.provider)
160176

161177
def __getitem__(self, key):
@@ -175,10 +191,23 @@ class Listen(BaseResponse):
175191
This class defines any configuration settings for the Listen model.
176192
"""
177193

178-
provider: Provider = field(default_factory=Provider)
194+
provider: Provider = field(
195+
default_factory=Provider,
196+
metadata=dataclass_config(
197+
exclude=lambda f: (
198+
f is None
199+
or (isinstance(f, dict) and not f)
200+
or (isinstance(f, Provider) and not f)
201+
)
202+
),
203+
)
179204

180205
def __post_init__(self):
181-
if not isinstance(self.provider, Provider):
206+
if (
207+
not isinstance(self.provider, Provider)
208+
and self.provider is not None
209+
and not (isinstance(self.provider, dict) and not self.provider)
210+
):
182211
self.provider = Provider(self.provider)
183212

184213
def __getitem__(self, key):
@@ -192,13 +221,26 @@ class Speak(BaseResponse):
192221
This class defines any configuration settings for the Speak model.
193222
"""
194223

195-
provider: Provider = field(default_factory=Provider)
224+
provider: Provider = field(
225+
default_factory=Provider,
226+
metadata=dataclass_config(
227+
exclude=lambda f: (
228+
f is None
229+
or (isinstance(f, dict) and not f)
230+
or (isinstance(f, Provider) and not f)
231+
)
232+
),
233+
)
196234
endpoint: Optional[Endpoint] = field(
197235
default=None, metadata=dataclass_config(exclude=lambda f: f is None)
198236
)
199237

200238
def __post_init__(self):
201-
if not isinstance(self.provider, Provider):
239+
if (
240+
not isinstance(self.provider, Provider)
241+
and self.provider is not None
242+
and not (isinstance(self.provider, dict) and not self.provider)
243+
):
202244
self.provider = Provider(self.provider)
203245

204246
def __getitem__(self, key):
@@ -215,9 +257,30 @@ class Agent(BaseResponse):
215257
"""
216258

217259
language: str = field(default="en")
218-
listen: Listen = field(default_factory=Listen)
219-
think: Think = field(default_factory=Think)
220-
speak: Speak = field(default_factory=Speak)
260+
listen: Listen = field(
261+
default_factory=Listen,
262+
metadata=dataclass_config(
263+
exclude=lambda f: f is None
264+
or (isinstance(f, dict) and not f)
265+
or (isinstance(f, Listen) and not f)
266+
),
267+
)
268+
think: Think = field(
269+
default_factory=Think,
270+
metadata=dataclass_config(
271+
exclude=lambda f: f is None
272+
or (isinstance(f, dict) and not f)
273+
or (isinstance(f, Think) and not f)
274+
),
275+
)
276+
speak: Speak = field(
277+
default_factory=Speak,
278+
metadata=dataclass_config(
279+
exclude=lambda f: f is None
280+
or (isinstance(f, dict) and not f)
281+
or (isinstance(f, Speak) and not f)
282+
),
283+
)
221284
greeting: Optional[str] = field(
222285
default=None, metadata=dataclass_config(exclude=lambda f: f is None)
223286
)
@@ -231,6 +294,8 @@ def __getitem__(self, key):
231294
if "speak" in _dict and isinstance(_dict["speak"], dict):
232295
_dict["speak"] = Speak.from_dict(_dict["speak"])
233296
return _dict[key]
297+
298+
234299
@dataclass
235300
class Input(BaseResponse):
236301
"""
@@ -272,6 +337,7 @@ def __getitem__(self, key):
272337
_dict["output"] = Output.from_dict(_dict["output"])
273338
return _dict[key]
274339

340+
275341
@dataclass
276342
class SettingsOptions(BaseResponse):
277343
"""

examples/agent/no_mic/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def main():
5656
options.agent.think.provider.model = "gpt-4o-mini"
5757
options.agent.think.prompt = "You are a friendly AI assistant."
5858
options.agent.speak.provider.type = "deepgram"
59-
options.agent.speak.model = "aura-2-thalia-en"
59+
options.agent.speak.provider.model = "aura-2-thalia-en"
6060
options.agent.greeting = "Hello! How can I help you today?"
6161

6262
# Send Keep Alive messages

examples/agent/simple/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def main():
4242
"microphone_record": "true",
4343
"speaker_playback": "true",
4444
},
45-
# verbose=verboselogs.DEBUG,
45+
verbose=verboselogs.SPAM,
4646
)
4747
print("Created DeepgramClientOptions...")
4848

0 commit comments

Comments
 (0)