Skip to content

Commit

Permalink
add back postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty committed Oct 4, 2024
1 parent 5bebd3b commit 2924dfa
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 16 deletions.
1 change: 1 addition & 0 deletions .github/workflows/integration-test-workflow-debian.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
POSTGRES_PORT: 5432
POSTGRES_PASSWORD: postgres
POSTGRES_USER: postgres
R2R_PROJECT_NAME: r2r_default
steps:
- name: Install and configure PostgreSQL
run: |
Expand Down
1 change: 1 addition & 0 deletions py/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
"PipeType",
## PROVIDERS
# Base provider classes
"AppConfig",
"Provider",
"ProviderConfig",
# Auth provider
Expand Down
1 change: 1 addition & 0 deletions py/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
"PipeType",
## PROVIDERS
# Base provider classes
"AppConfig",
"Provider",
"ProviderConfig",
# Auth provider
Expand Down
38 changes: 22 additions & 16 deletions py/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from core import (
AppConfig,
AuthConfig,
BCryptConfig,
CompletionConfig,
Expand Down Expand Up @@ -59,32 +60,35 @@ def generate_random_vector_entry(
generate_random_vector_entry(i, dimension) for i in range(num_entries)
]

@pytest.fixture(scope="session")
def app_config():
return AppConfig()

# Crypto
@pytest.fixture(scope="session")
def crypto_config():
return BCryptConfig()
def crypto_config(app_config):
return BCryptConfig(app=app_config)


@pytest.fixture(scope="session")
def crypto_provider(crypto_config):
def crypto_provider(crypto_config, app_config):
return BCryptProvider(crypto_config)


# Postgres
@pytest.fixture(scope="session")
def db_config():
def db_config(app_config):
collection_id = uuid.uuid4()

random_project_name = f"test_collection_{collection_id.hex}"
return DatabaseConfig.create(
provider="postgres", project_name=random_project_name
provider="postgres", project_name=random_project_name, app=app_config
)


@pytest.fixture(scope="function")
async def postgres_db_provider(
db_config, dimension, crypto_provider, sample_entries
db_config, dimension, crypto_provider, sample_entries, app_config
):
db = PostgresDBProvider(
db_config, dimension=dimension, crypto_provider=crypto_provider
Expand All @@ -98,12 +102,12 @@ async def postgres_db_provider(


@pytest.fixture(scope="function")
def db_config_temporary():
def db_config_temporary(app_config):
collection_id = uuid.uuid4()

random_project_name = f"test_collection_{collection_id.hex}"
return DatabaseConfig.create(
provider="postgres", project_name=random_project_name
provider="postgres", project_name=random_project_name, app=app_config
)


Expand All @@ -127,12 +131,13 @@ async def temporary_postgres_db_provider(

# Auth
@pytest.fixture(scope="session")
def auth_config():
def auth_config(app_config):
return AuthConfig(
secret_key="test_secret_key",
access_token_lifetime_in_minutes=15,
refresh_token_lifetime_in_days=1,
require_email_verification=False,
app=app_config
)


Expand All @@ -149,19 +154,20 @@ async def r2r_auth_provider(

# Embeddings
@pytest.fixture
def litellm_provider():
def litellm_provider(app_config):
config = EmbeddingConfig(
provider="litellm",
base_model="text-embedding-3-small",
base_dimension=1536,
app=app_config
)
return LiteLLMEmbeddingProvider(config)


# File Provider
@pytest.fixture(scope="function")
def file_config():
return FileConfig(provider="postgres")
def file_config(app_config):
return FileConfig(provider="postgres", app=app_config)


@pytest.fixture(scope="function")
Expand All @@ -176,18 +182,18 @@ async def postgres_file_provider(file_config, temporary_postgres_db_provider):

# LLM provider
@pytest.fixture
def litellm_completion_provider():
config = CompletionConfig(provider="litellm")
def litellm_completion_provider(app_config):
config = CompletionConfig(provider="litellm", app=app_config)
return LiteCompletionProvider(config)


# Logging
@pytest.fixture(scope="function")
async def local_logging_provider():
async def local_logging_provider(app_config):
unique_id = str(uuid.uuid4())
logging_path = f"test_{unique_id}.sqlite"
provider = LocalRunLoggingProvider(
LoggingConfig(logging_path=logging_path)
LoggingConfig(logging_path=logging_path, app=app_config)
)
await provider._init()
yield provider
Expand Down

0 comments on commit 2924dfa

Please sign in to comment.