Skip to content

Commit 9060b89

Browse files
authored
fix: Weaviate - fix types + add py.typed (#1977)
1 parent 1c9e06f commit 9060b89

File tree

8 files changed

+38
-32
lines changed

8 files changed

+38
-32
lines changed

.github/workflows/weaviate.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,9 @@ jobs:
4444
- name: Install Hatch
4545
run: pip install --upgrade hatch
4646

47-
# TODO: Once this integration is properly typed, use hatch run test:types
48-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
4947
- name: Lint
5048
if: matrix.python-version == '3.9' && runner.os == 'Linux'
51-
run: hatch run fmt-check && hatch run lint:typing
49+
run: hatch run fmt-check && hatch run test:types
5250

5351
- name: Run Weaviate container
5452
run: docker compose up -d

integrations/weaviate/pyproject.toml

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,18 +70,14 @@ integration = 'pytest -m "integration" {args:tests}'
7070
all = 'pytest {args:tests}'
7171
cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x'
7272

73-
types = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
74-
75-
# TODO: remove lint environment once this integration is properly typed
76-
# test environment should be used instead
77-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
78-
[tool.hatch.envs.lint]
79-
installer = "uv"
80-
detached = true
81-
dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"]
82-
83-
[tool.hatch.envs.lint.scripts]
84-
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
73+
types = """mypy -p haystack_integrations.document_stores.weaviate \
74+
-p haystack_integrations.components.retrievers.weaviate {args}"""
75+
76+
[tool.mypy]
77+
install_types = true
78+
non_interactive = true
79+
check_untyped_defs = true
80+
disallow_incomplete_defs = true
8581

8682
[tool.black]
8783
target-version = ["py38"]

integrations/weaviate/src/haystack_integrations/components/retrievers/py.typed

Whitespace-only changes.

integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/bm25_retriever.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "WeaviateBM25Retriever":
8989
return default_from_dict(cls, data)
9090

9191
@component.output_types(documents=List[Document])
92-
def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
92+
def run(
93+
self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None
94+
) -> Dict[str, List[Document]]:
9395
"""
9496
Retrieves documents from Weaviate using the BM25 algorithm.
9597

integrations/weaviate/src/haystack_integrations/components/retrievers/weaviate/embedding_retriever.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def run(
106106
top_k: Optional[int] = None,
107107
distance: Optional[float] = None,
108108
certainty: Optional[float] = None,
109-
):
109+
) -> Dict[str, List[Document]]:
110110
"""
111111
Retrieves documents from Weaviate using the vector search.
112112

integrations/weaviate/src/haystack_integrations/document_stores/py.typed

Whitespace-only changes.

integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __str__(self):
3030
return self.value
3131

3232
@staticmethod
33-
def from_class(auth_class) -> "SupportedAuthTypes":
33+
def from_class(auth_class: Type["AuthCredentials"]) -> "SupportedAuthTypes":
3434
auth_types = {
3535
AuthApiKey: SupportedAuthTypes.API_KEY,
3636
AuthBearerToken: SupportedAuthTypes.BEARER,
@@ -80,7 +80,7 @@ def from_dict(data: Dict[str, Any]) -> "AuthCredentials":
8080

8181
@classmethod
8282
@abstractmethod
83-
def _from_dict(cls, data: Dict[str, Any]):
83+
def _from_dict(cls, data: Dict[str, Any]) -> "AuthCredentials":
8484
"""
8585
Internal method to convert a dictionary representation to an auth credentials object.
8686
All subclasses must implement this method.
@@ -109,7 +109,8 @@ def _from_dict(cls, data: Dict[str, Any]) -> "AuthApiKey":
109109
return cls(**data["init_parameters"])
110110

111111
def resolve_value(self) -> WeaviateAuthApiKey:
112-
return WeaviateAuthApiKey(api_key=self.api_key.resolve_value())
112+
# resolve_value, when used with Secret.from_env_var (strict=True), returns a string or raises an error
113+
return WeaviateAuthApiKey(api_key=self.api_key.resolve_value()) # type: ignore[arg-type]
113114

114115

115116
@dataclass(frozen=True)
@@ -136,9 +137,11 @@ def resolve_value(self) -> WeaviateAuthBearerToken:
136137
refresh_token = self.refresh_token.resolve_value()
137138

138139
return WeaviateAuthBearerToken(
139-
access_token=access_token,
140+
# resolve_value, when used with Secret.from_env_var (strict=True), returns a string or raises an error
141+
access_token=access_token, # type: ignore[arg-type]
140142
expires_in=self.expires_in,
141-
refresh_token=refresh_token,
143+
# resolve_value, when used with Secret.from_env_var (strict=False), returns a string or None
144+
refresh_token=refresh_token, # type: ignore[arg-type]
142145
)
143146

144147

@@ -162,8 +165,10 @@ def _from_dict(cls, data: Dict[str, Any]) -> "AuthClientCredentials":
162165

163166
def resolve_value(self) -> WeaviateAuthClientCredentials:
164167
return WeaviateAuthClientCredentials(
165-
client_secret=self.client_secret.resolve_value(),
166-
scope=self.scope.resolve_value(),
168+
# resolve_value, when used with Secret.from_env_var (strict=True), returns a string or raises an error
169+
client_secret=self.client_secret.resolve_value(), # type: ignore[arg-type]
170+
# resolve_value, when used with Secret.from_env_var (strict=False), returns a string or None
171+
scope=self.scope.resolve_value(), # type: ignore[arg-type]
167172
)
168173

169174

@@ -189,7 +194,9 @@ def _from_dict(cls, data: Dict[str, Any]) -> "AuthClientPassword":
189194

190195
def resolve_value(self) -> WeaviateAuthClientPassword:
191196
return WeaviateAuthClientPassword(
192-
username=self.username.resolve_value(),
193-
password=self.password.resolve_value(),
194-
scope=self.scope.resolve_value(),
197+
# resolve_value, when used with Secret.from_env_var (strict=True), returns a string or raises an error
198+
username=self.username.resolve_value(), # type: ignore[arg-type]
199+
password=self.password.resolve_value(), # type: ignore[arg-type]
200+
# resolve_value, when used with Secret.from_env_var (strict=False), returns a string or None
201+
scope=self.scope.resolve_value(), # type: ignore[arg-type]
195202
)

integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ def __init__(
144144
self._additional_config = additional_config
145145
self._grpc_port = grpc_port
146146
self._grpc_secure = grpc_secure
147-
self._client = None
148-
self._collection = None
147+
self._client: Optional[weaviate.WeaviateClient] = None
148+
self._collection: Optional[weaviate.Collection] = None
149149
# Store the collection settings dictionary
150150
self._collection_settings = collection_settings or {
151151
"class": "Default",
@@ -173,9 +173,12 @@ def client(self):
173173
# If we detect that the URL is a Weaviate Cloud URL, we use the utility function to connect
174174
# instead of using WeaviateClient directly like in other cases.
175175
# Among other things, the utility function takes care of parsing the URL.
176+
if not self._auth_client_secret:
177+
msg = "Auth credentials are required for Weaviate Cloud Services"
178+
raise ValueError(msg)
176179
self._client = weaviate.connect_to_weaviate_cloud(
177180
self._url,
178-
auth_credentials=self._auth_client_secret.resolve_value() if self._auth_client_secret else None,
181+
auth_credentials=self._auth_client_secret.resolve_value(),
179182
headers=self._additional_headers,
180183
additional_config=self._additional_config,
181184
)
@@ -343,7 +346,7 @@ def _to_document(self, data: DataObject[Dict[str, Any], None]) -> Document:
343346

344347
return Document.from_dict(document_data)
345348

346-
def _query(self) -> List[Dict[str, Any]]:
349+
def _query(self) -> List[DataObject[Dict[str, Any], None]]:
347350
properties = [p.name for p in self.collection.config.get().properties]
348351
try:
349352
result = self.collection.iterator(include_vector=True, return_properties=properties)
@@ -352,7 +355,7 @@ def _query(self) -> List[Dict[str, Any]]:
352355
raise DocumentStoreError(msg) from e
353356
return result
354357

355-
def _query_with_filters(self, filters: Dict[str, Any]) -> List[Dict[str, Any]]:
358+
def _query_with_filters(self, filters: Dict[str, Any]) -> List[DataObject[Dict[str, Any], None]]:
356359
properties = [p.name for p in self.collection.config.get().properties]
357360
# When querying with filters we need to paginate using limit and offset as using
358361
# a cursor with after is not possible. See the official docs:

0 commit comments

Comments
 (0)