Skip to content

Commit d2e0298

Browse files
authored
Simplify weaviate auth (#223)
* add function to check host Signed-off-by: hsm207 <hsm207@users.noreply.github.com> * fix bug in test Signed-off-by: hsm207 <hsm207@users.noreply.github.com> * simplify auth build logic Signed-off-by: hsm207 <hsm207@users.noreply.github.com> * remove useless en vars Signed-off-by: hsm207 <hsm207@users.noreply.github.com> * remove useless env var Signed-off-by: hsm207 <hsm207@users.noreply.github.com> * add default url Signed-off-by: hsm207 <hsm207@users.noreply.github.com> * remove todo Signed-off-by: hsm207 <hsm207@users.noreply.github.com> * code cleanup Signed-off-by: hsm207 <hsm207@users.noreply.github.com> * code cleanup Signed-off-by: hsm207 <hsm207@users.noreply.github.com> * Update README Signed-off-by: hsm207 <hsm207@users.noreply.github.com> * Remove batch config env vars * fix regex to also check for WCS enterprise cluster Signed-off-by: hsm207 <hsm207@users.noreply.github.com> --------- Signed-off-by: hsm207 <hsm207@users.noreply.github.com>
1 parent 0ebb015 commit d2e0298

File tree

4 files changed

+94
-95
lines changed

4 files changed

+94
-95
lines changed

README.md

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,9 @@ Follow these steps to quickly set up and run the ChatGPT Retrieval Plugin:
8888
export PINECONE_INDEX=<your_pinecone_index>
8989
9090
# Weaviate
91-
export WEAVIATE_HOST=<your_weaviate_instance_url>
92-
export WEAVIATE_PORT=<your_weaviate_port_443_for_WCS>
91+
export WEAVIATE_URL=<your_weaviate_instance_url>
92+
export WEAVIATE_API_KEY=<your_api_key_for_WCS>
9393
export WEAVIATE_CLASS=<your_optional_weaviate_class>
94-
export WEAVIATE_USERNAME=<your_weaviate_WCS_username>
95-
export WEAVIATE_PASSWORD=<your_weaviate_WCS_password>
96-
export WEAVIATE_SCOPES=<your_optional_weaviate_scopes>
97-
export WEAVIATE_BATCH_SIZE=<optional_weaviate_batch_size>
98-
export WEAVIATE_BATCH_DYNAMIC=<optional_weaviate_batch_dynamic>
99-
export WEAVIATE_BATCH_TIMEOUT_RETRIES=<optional_weaviate_batch_timeout_retries>
100-
export WEAVIATE_BATCH_NUM_WORKERS=<optional_weaviate_batch_num_workers>
10194
10295
# Zilliz
10396
export ZILLIZ_COLLECTION=<your_zilliz_collection>

datastore/providers/weaviate_datastore.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,26 @@
1-
# TODO
21
import asyncio
3-
from typing import Dict, List, Optional
4-
from loguru import logger
5-
from weaviate import Client
6-
import weaviate
72
import os
3+
import re
84
import uuid
5+
from typing import Dict, List, Optional
96

7+
import weaviate
8+
from loguru import logger
9+
from weaviate import Client
1010
from weaviate.util import generate_uuid5
1111

1212
from datastore.datastore import DataStore
1313
from models.models import (
1414
DocumentChunk,
1515
DocumentChunkMetadata,
16+
DocumentChunkWithScore,
1617
DocumentMetadataFilter,
1718
QueryResult,
1819
QueryWithEmbedding,
19-
DocumentChunkWithScore,
2020
Source,
2121
)
2222

23-
24-
WEAVIATE_HOST = os.environ.get("WEAVIATE_HOST", "http://127.0.0.1")
25-
WEAVIATE_PORT = os.environ.get("WEAVIATE_PORT", "8080")
26-
WEAVIATE_USERNAME = os.environ.get("WEAVIATE_USERNAME", None)
27-
WEAVIATE_PASSWORD = os.environ.get("WEAVIATE_PASSWORD", None)
28-
WEAVIATE_SCOPES = os.environ.get("WEAVIATE_SCOPES", "offline_access")
23+
WEAVIATE_URL_DEFAULT = "http://localhost:8080"
2924
WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "OpenAIDocument")
3025

3126
WEAVIATE_BATCH_SIZE = int(os.environ.get("WEAVIATE_BATCH_SIZE", 20))
@@ -109,7 +104,7 @@ def handle_errors(self, results: Optional[List[dict]]) -> List[str]:
109104
def __init__(self):
110105
auth_credentials = self._build_auth_credentials()
111106

112-
url = f"{WEAVIATE_HOST}:{WEAVIATE_PORT}"
107+
url = os.environ.get("WEAVIATE_URL", WEAVIATE_URL_DEFAULT)
113108

114109
logger.debug(
115110
f"Connecting to weaviate instance at {url} with credential type {type(auth_credentials).__name__}"
@@ -140,10 +135,14 @@ def __init__(self):
140135

141136
@staticmethod
142137
def _build_auth_credentials():
143-
if WEAVIATE_USERNAME and WEAVIATE_PASSWORD:
144-
return weaviate.auth.AuthClientPassword(
145-
WEAVIATE_USERNAME, WEAVIATE_PASSWORD, WEAVIATE_SCOPES
146-
)
138+
url = os.environ.get("WEAVIATE_URL", WEAVIATE_URL_DEFAULT)
139+
140+
if WeaviateDataStore._is_wcs_domain(url):
141+
api_key = os.environ.get("WEAVIATE_API_KEY")
142+
if api_key is not None:
143+
return weaviate.auth.AuthApiKey(api_key=api_key)
144+
else:
145+
raise ValueError("WEAVIATE_API_KEY environment variable is not set")
147146
else:
148147
return None
149148

@@ -370,3 +369,17 @@ def _is_valid_weaviate_id(candidate_id: str) -> bool:
370369
return True
371370
except ValueError:
372371
return False
372+
373+
@staticmethod
def _is_wcs_domain(url: str) -> bool:
    """
    Check if the given URL ends with a WCS domain.

    A URL matches when it ends with ".weaviate.network" or ".weaviate.cloud",
    optionally followed by a single trailing "/". URLs that carry a path after
    the host, or that only contain one of these domains somewhere in the
    middle, do not match.

    Args:
        url (str): The URL to check.

    Returns:
        bool: True if the URL ends with one of the WCS domains, False otherwise.
    """
    pattern = r"\.(weaviate\.cloud|weaviate\.network)(/)?$"
    # Host names are case-insensitive, so "X.WEAVIATE.NETWORK" must be
    # recognised as WCS too; otherwise the caller would skip authentication.
    return bool(re.search(pattern, url, flags=re.IGNORECASE))

docs/providers/weaviate/setup.md

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -65,39 +65,15 @@ You need to set some environment variables to connect to your Weaviate instance.
6565
6666
| Name | Required | Description | Default |
6767
|------------------| -------- | ------------------------------------------------------------------ | ------------------ |
68-
| `WEAVIATE_HOST` | Optional | Your Weaviate instance host address (see notes below) | `http://127.0.0.1` |
69-
| `WEAVIATE_PORT` | Optional | Your Weaviate port number (use 443 for WCS) | 8080 |
68+
| `WEAVIATE_URL`   | Optional | Your Weaviate instance's URL/WCS endpoint                          | `http://localhost:8080` |
7069
| `WEAVIATE_CLASS` | Optional | Your chosen Weaviate class/collection name to store your documents | OpenAIDocument |
7170
72-
> For **WCS instances**, set `WEAVIATE_PORT` to 443 and `WEAVIATE_HOST` to `https://(wcs-instance-name).weaviate.network`. For example: `https://my-project.weaviate.network/`.
73-
74-
> For **self-hosted instances**, if your instance is not at 127.0.0.1:8080, set `WEAVIATE_HOST` and `WEAVIATE_PORT` accordingly. For example: `WEAVIATE_HOST=http://localhost/` and `WEAVIATE_PORT=4040`.
75-
7671
**Weaviate Auth Environment Variables**
7772
78-
If you enabled OIDC authentication for your Weaviate instance (recommended for WCS instances), set the following environment variables. If you enabled anonymous access, skip this section.
73+
If you are using a WCS instance, set the following environment variable:
7974
8075
| Name | Required | Description |
8176
| ------------------- | -------- | ------------------------------ |
82-
| `WEAVIATE_USERNAME` | Yes | Your OIDC or WCS username |
83-
| `WEAVIATE_PASSWORD` | Yes | Your OIDC or WCS password |
84-
| `WEAVIATE_SCOPES` | Optional | Space-separated list of scopes |
85-
86-
Learn more about [authentication in Weaviate](https://weaviate.io/developers/weaviate/configuration/authentication#overview) and the [Python client authentication](https://weaviate-python-client.readthedocs.io/en/stable/weaviate.auth.html).
87-
88-
**Weaviate Batch Import Environment Variables**
89-
90-
Weaviate uses a batching mechanism to perform operations in bulk. This makes importing and updating your data faster and more efficient. You can adjust the batch settings with these optional environment variables:
91-
92-
| Name | Required | Description | Default |
93-
| -------------------------------- | -------- | ------------------------------------------------------------ | ------- |
94-
| `WEAVIATE_BATCH_SIZE` | Optional | Number of insert/updates per batch operation | 20 |
95-
| `WEAVIATE_BATCH_DYNAMIC` | Optional | Lets the batch process decide the batch size | False |
96-
| `WEAVIATE_BATCH_TIMEOUT_RETRIES` | Optional | Number of retry-on-timeout attempts | 3 |
97-
| `WEAVIATE_BATCH_NUM_WORKERS` | Optional | The max number of concurrent threads to run batch operations | 1 |
98-
99-
> **Note:** The optimal `WEAVIATE_BATCH_SIZE` depends on the available resources (RAM, CPU). A higher value means faster bulk operations, but also higher demand for RAM and CPU. If you experience failures during the import process, reduce the batch size.
100-
101-
> Setting `WEAVIATE_BATCH_SIZE` to `None` means no limit to the batch size. All insert or update operations would be sent to Weaviate in a single operation. This might be risky, as you lose control over the batch size.
77+
| `WEAVIATE_API_KEY`  | Yes      | Your WCS API key               |
10278
103-
Learn more about [batch configuration in Weaviate](https://weaviate.io/developers/weaviate/client-libraries/python#batch-configuration).
79+
Learn more about accessing your [WCS API key](https://weaviate.io/developers/wcs/guides/authentication#access-api-keys).

tests/datastore/providers/weaviate/test_weaviate_datastore.py

Lines changed: 58 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
1+
import logging
2+
import os
3+
14
import pytest
5+
import weaviate
6+
from _pytest.logging import LogCaptureFixture
27
from fastapi.testclient import TestClient
8+
from loguru import logger
39
from weaviate import Client
4-
import weaviate
5-
import os
6-
from models.models import DocumentMetadataFilter, Source
7-
from server.main import app
10+
811
from datastore.providers.weaviate_datastore import (
912
SCHEMA,
1013
WeaviateDataStore,
1114
extract_schema_properties,
1215
)
13-
import logging
14-
from loguru import logger
15-
from _pytest.logging import LogCaptureFixture
16+
from models.models import DocumentMetadataFilter, Source
17+
from server.main import app
1618

1719
BEARER_TOKEN = os.getenv("BEARER_TOKEN")
1820

@@ -99,30 +101,6 @@ def documents():
99101
yield documents
100102

101103

102-
@pytest.fixture
103-
def mock_env_public_access(monkeypatch):
104-
monkeypatch.setattr(
105-
"datastore.providers.weaviate_datastore.WEAVIATE_USERNAME", None
106-
)
107-
monkeypatch.setattr(
108-
"datastore.providers.weaviate_datastore.WEAVIATE_PASSWORD", None
109-
)
110-
111-
112-
@pytest.fixture
113-
def mock_env_resource_owner_password_flow(monkeypatch):
114-
monkeypatch.setattr(
115-
"datastore.providers.weaviate_datastore.WEAVIATE_SCOPES",
116-
["schema:read", "schema:write"],
117-
)
118-
monkeypatch.setattr(
119-
"datastore.providers.weaviate_datastore.WEAVIATE_USERNAME", "admin"
120-
)
121-
monkeypatch.setattr(
122-
"datastore.providers.weaviate_datastore.WEAVIATE_PASSWORD", "abc123"
123-
)
124-
125-
126104
@pytest.fixture
127105
def caplog(caplog: LogCaptureFixture):
128106
handler_id = logger.add(caplog.handler, format="{message}")
@@ -337,16 +315,38 @@ def test_delete(test_db, weaviate_client, caplog):
337315
assert not weaviate_client.data_object.get()["objects"]
338316

339317

340-
def test_access_with_username_password(mock_env_resource_owner_password_flow):
341-
auth_credentials = WeaviateDataStore._build_auth_credentials()
342-
343-
assert isinstance(auth_credentials, weaviate.auth.AuthClientPassword)
344-
345-
346-
def test_public_access(mock_env_public_access):
347-
auth_credentials = WeaviateDataStore._build_auth_credentials()
348-
349-
assert auth_credentials is None
318+
def test_build_auth_credentials(monkeypatch):
    """Check _build_auth_credentials under four environment configurations."""
    wcs_url = "https://example.weaviate.network"
    api_key = "your_api_key"

    # WCS URL with an API key: API-key credentials carrying that key.
    with monkeypatch.context() as env:
        env.setenv("WEAVIATE_URL", wcs_url)
        env.setenv("WEAVIATE_API_KEY", api_key)
        credentials = WeaviateDataStore._build_auth_credentials()
        assert credentials is not None
        assert isinstance(credentials, weaviate.auth.AuthApiKey)
        assert credentials.api_key == api_key

    # WCS URL without an API key: building credentials must fail loudly.
    with monkeypatch.context() as env:
        env.setenv("WEAVIATE_URL", wcs_url)
        env.delenv("WEAVIATE_API_KEY", raising=False)
        expected_error = "WEAVIATE_API_KEY environment variable is not set"
        with pytest.raises(ValueError, match=expected_error):
            WeaviateDataStore._build_auth_credentials()

    # Non-WCS URL: no credentials, even though an API key is present.
    with monkeypatch.context() as env:
        env.setenv("WEAVIATE_URL", "https://example.notweaviate.network")
        env.setenv("WEAVIATE_API_KEY", api_key)
        credentials = WeaviateDataStore._build_auth_credentials()
        assert credentials is None

    # WEAVIATE_URL unset: no credentials, even though an API key is present.
    with monkeypatch.context() as env:
        env.delenv("WEAVIATE_URL", raising=False)
        env.setenv("WEAVIATE_API_KEY", api_key)
        credentials = WeaviateDataStore._build_auth_credentials()
        assert credentials is None
350350

351351

352352
def test_extract_schema_properties():
@@ -519,3 +519,20 @@ def build_upsert_payload(document):
519519
# but it is None right now because an
520520
# update function is out of scope
521521
assert weaviate_doc[0]["source"] is None
522+
523+
524+
@pytest.mark.parametrize(
    "url, expected_result",
    [
        # URLs ending in a WCS domain, with or without one trailing slash.
        *(
            (wcs_url, True)
            for wcs_url in (
                "https://example.weaviate.network",
                "https://example.weaviate.network/",
                "https://example.weaviate.cloud",
                "https://example.weaviate.cloud/",
            )
        ),
        # Look-alike hosts, a WCS host followed by a path, and the empty string.
        *(
            (other_url, False)
            for other_url in (
                "https://example.notweaviate.network",
                "https://weaviate.network.example.com",
                "https://example.weaviate.network/somepage",
                "",
            )
        ),
    ],
)
def test_is_wcs_domain(url, expected_result):
    """Verify _is_wcs_domain returns the expected verdict for each URL."""
    is_wcs = WeaviateDataStore._is_wcs_domain(url)
    assert is_wcs == expected_result

0 commit comments

Comments
 (0)