6
6
import pytest
7
7
8
8
from core import (
9
+ AppConfig ,
9
10
AuthConfig ,
10
11
BCryptConfig ,
11
12
CompletionConfig ,
@@ -59,32 +60,35 @@ def generate_random_vector_entry(
59
60
generate_random_vector_entry (i , dimension ) for i in range (num_entries )
60
61
]
61
62
63
+ @pytest .fixture (scope = "session" )
64
+ def app_config ():
65
+ return AppConfig ()
62
66
63
67
# Crypto
64
68
@pytest .fixture (scope = "session" )
65
- def crypto_config ():
66
- return BCryptConfig ()
69
+ def crypto_config (app_config ):
70
+ return BCryptConfig (app = app_config )
67
71
68
72
69
73
@pytest .fixture (scope = "session" )
70
- def crypto_provider (crypto_config ):
74
+ def crypto_provider (crypto_config , app_config ):
71
75
return BCryptProvider (crypto_config )
72
76
73
77
74
78
# Postgres
75
79
@pytest .fixture (scope = "session" )
76
- def db_config ():
80
+ def db_config (app_config ):
77
81
collection_id = uuid .uuid4 ()
78
82
79
83
random_project_name = f"test_collection_{ collection_id .hex } "
80
84
return DatabaseConfig .create (
81
- provider = "postgres" , project_name = random_project_name
85
+ provider = "postgres" , project_name = random_project_name , app = app_config
82
86
)
83
87
84
88
85
89
@pytest .fixture (scope = "function" )
86
90
async def postgres_db_provider (
87
- db_config , dimension , crypto_provider , sample_entries
91
+ db_config , dimension , crypto_provider , sample_entries , app_config
88
92
):
89
93
db = PostgresDBProvider (
90
94
db_config , dimension = dimension , crypto_provider = crypto_provider
@@ -98,12 +102,12 @@ async def postgres_db_provider(
98
102
99
103
100
104
@pytest .fixture (scope = "function" )
101
- def db_config_temporary ():
105
+ def db_config_temporary (app_config ):
102
106
collection_id = uuid .uuid4 ()
103
107
104
108
random_project_name = f"test_collection_{ collection_id .hex } "
105
109
return DatabaseConfig .create (
106
- provider = "postgres" , project_name = random_project_name
110
+ provider = "postgres" , project_name = random_project_name , app = app_config
107
111
)
108
112
109
113
@@ -127,12 +131,13 @@ async def temporary_postgres_db_provider(
127
131
128
132
# Auth
129
133
@pytest .fixture (scope = "session" )
130
- def auth_config ():
134
+ def auth_config (app_config ):
131
135
return AuthConfig (
132
136
secret_key = "test_secret_key" ,
133
137
access_token_lifetime_in_minutes = 15 ,
134
138
refresh_token_lifetime_in_days = 1 ,
135
139
require_email_verification = False ,
140
+ app = app_config
136
141
)
137
142
138
143
@@ -149,19 +154,20 @@ async def r2r_auth_provider(
149
154
150
155
# Embeddings
151
156
@pytest .fixture
152
- def litellm_provider ():
157
+ def litellm_provider (app_config ):
153
158
config = EmbeddingConfig (
154
159
provider = "litellm" ,
155
160
base_model = "text-embedding-3-small" ,
156
161
base_dimension = 1536 ,
162
+ app = app_config
157
163
)
158
164
return LiteLLMEmbeddingProvider (config )
159
165
160
166
161
167
# File Provider
162
168
@pytest .fixture (scope = "function" )
163
- def file_config ():
164
- return FileConfig (provider = "postgres" )
169
+ def file_config (app_config ):
170
+ return FileConfig (provider = "postgres" , app = app_config )
165
171
166
172
167
173
@pytest .fixture (scope = "function" )
@@ -176,18 +182,18 @@ async def postgres_file_provider(file_config, temporary_postgres_db_provider):
176
182
177
183
# LLM provider
178
184
@pytest .fixture
179
- def litellm_completion_provider ():
180
- config = CompletionConfig (provider = "litellm" )
185
+ def litellm_completion_provider (app_config ):
186
+ config = CompletionConfig (provider = "litellm" , app = app_config )
181
187
return LiteCompletionProvider (config )
182
188
183
189
184
190
# Logging
185
191
@pytest .fixture (scope = "function" )
186
- async def local_logging_provider ():
192
+ async def local_logging_provider (app_config ):
187
193
unique_id = str (uuid .uuid4 ())
188
194
logging_path = f"test_{ unique_id } .sqlite"
189
195
provider = LocalRunLoggingProvider (
190
- LoggingConfig (logging_path = logging_path )
196
+ LoggingConfig (logging_path = logging_path , app = app_config )
191
197
)
192
198
await provider ._init ()
193
199
yield provider
0 commit comments