Skip to content

Commit 8097016

Browse files
committed
Fixed a bug where using query arguments other than the q argument would return different results. Enabled functionality for saving the polars dataframe as a CSV. Adjusted tests accordingly.
1 parent 4d7a275 commit 8097016

File tree

3 files changed

+84
-31
lines changed

3 files changed

+84
-31
lines changed

industryDocumentsWrapper/ucsf_api.py

+44-21
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,31 @@
11
from dataclasses import dataclass
22
import re
3+
import time
34
import requests
45
import polars as pl
56

67

8+
BATCH_TIMEOUT = 30 # seconds
9+
RATE_LIMIT = 0.1 # seconds between requests
10+
711
@dataclass
812
class IndustryDocsSearch:
913
"""
1014
UCSF Industry Documents Library Solr API Wrapper Class.
1115
1216
API Documentation found here: https://www.industrydocuments.ucsf.edu/wp-content/uploads/2020/08/IndustryDocumentsDataAPI_v7.pdf
1317
"""
14-
base_url = "https://metadata.idl.ucsf.edu/solr/ltdl3/"
15-
results = []
18+
def __init__(self):
19+
self.__base_url = "https://metadata.idl.ucsf.edu/solr/ltdl3/"
20+
self.results = []
1621

1722
def _create_query(self, **kwargs) -> str:
1823
"""Constructs parametrized query"""
1924
if kwargs['q']:
20-
query = f"{self.base_url}query?q=({kwargs['q']})&wt={kwargs['wt']}&cursorMark={kwargs['cursorMark']}&sort={kwargs['sort']}"
25+
query = f"{self.__base_url}query?q=({kwargs['q']})&wt={kwargs['wt']}&cursorMark={kwargs['cursorMark']}&sort={kwargs['sort']}"
2126
else:
22-
query = f"{self.base_url}query?q=("+' AND '.join([f'{k}:{v}' for k, v in kwargs.items() if v and k != 'wt' and k != 'cursorMark' and k != 'sort' and k != 'n'])+f")&wt={kwargs['wt']}&cursorMark={kwargs['cursorMark']}&sort={kwargs['sort']}"
23-
27+
query = f"{self.__base_url}query?q=("+' AND '.join([f'{k}:"{v}"' for k, v in kwargs.items() if v and k != 'wt' and k != 'cursorMark' and k != 'sort' and k != 'n'])+f")&wt={kwargs['wt']}&cursorMark={kwargs['cursorMark']}&sort={kwargs['sort']}"
28+
print(query)
2429
return query
2530

2631
def _update_cursormark(self, query:str, cursor_mark: str) -> str:
@@ -32,29 +37,37 @@ def _loop_results(self, query:str, n:int) -> None:
3237
next_cursor = None
3338
current_cursor = '*' # initial cursor mark
3439

40+
# Get initial response to check total available documents
41+
initial_response = requests.get(query).json()
42+
total_available = initial_response['response']['numFound']
43+
print(f"Total available documents: {total_available}")
44+
45+
if n > total_available:
46+
print(f"Warning: Only {total_available} documents available, which is less than the {n} requested")
47+
n = total_available
48+
3549
if n == -1:
36-
n = float('inf')
50+
n = total_available
3751

3852
while (next_cursor != current_cursor) and (len(self.results) < n):
39-
4053
if next_cursor:
4154
current_cursor = next_cursor
4255
query = self._update_cursormark(query, current_cursor)
4356

44-
r = requests.get(query).json()
57+
r = requests.get(query, timeout=BATCH_TIMEOUT).json()
58+
docs = r['response']['docs']
4559

46-
if n < len(r['response']['docs']):
60+
if n < len(docs):
4761
self.results.extend(r['response']['docs'][:n])
48-
49-
elif n < (len(self.results) + len(r['response']['docs'])):
62+
elif n < (len(self.results) + len(docs)):
5063
self.results.extend(r['response']['docs'][:n-len(self.results)])
51-
5264
else:
53-
self.results.extend(r['response']['docs'])
65+
self.results.extend(docs)
5466

5567
next_cursor = r['nextCursorMark']
5668

5769
print(f"{len(self.results)}/{n} documents collected")
70+
time.sleep(RATE_LIMIT)
5871

5972
return
6073

@@ -67,7 +80,7 @@ def query(self,
6780
q:str = False,
6881
case:str = False,
6982
collection:str = False,
70-
doc_type:str = False,
83+
type:str = False,
7184
industry:str = False,
7285
brand:str = False,
7386
availability:str = False,
@@ -87,7 +100,7 @@ def query(self,
87100
query = self._create_query(q=q,
88101
case=case,
89102
collection=collection,
90-
type=doc_type,
103+
type=type,
91104
industry=industry,
92105
brand=brand,
93106
availability=availability,
@@ -115,9 +128,12 @@ def query(self,
115128
# TODO: Determine whether we need to maintain this load method
116129
def load(self, filename: str) -> pl.DataFrame:
117130
"""Reads results from a local CSV or JSON"""
118-
if not filename.lower().endswith('.parquet'):
119-
raise Exception("Only parquet format supported currently.")
120-
self.results = pl.read_parquet(filename)
131+
if filename.lower().endswith('.json'):
132+
self.results = pl.read_json(filename)
133+
elif filename.lower().endswith('.parquet'):
134+
self.results = pl.read_parquet(filename)
135+
elif filename.lower().endswith('.csv'):
136+
self.results = pl.read_csv(filename)
121137

122138

123139
def save(self, filename: str, format: str) -> None:
@@ -126,9 +142,16 @@ def save(self, filename: str, format: str) -> None:
126142
match format:
127143
case 'parquet':
128144
df.write_parquet(filename)
129-
# case 'csv':
130-
# df = df.with_columns(pl.col(pl.List, pl.Struct, pl.Array).list.join(","))
131-
# df.write_csv(filename)
145+
case 'csv':
146+
nested_cols = df.select([
147+
pl.col(col) for col in df.columns
148+
if pl.DataFrame(df).schema[col] in [pl.List, pl.Struct, pl.Array]
149+
]).columns
150+
if nested_cols:
151+
df = df.with_columns([
152+
pl.col(col).map_elements(lambda x: str(x) if x is not None else None, return_dtype=pl.Utf8) for col in nested_cols
153+
])
154+
df.write_csv(filename)
132155
case 'json':
133156
df.write_json(filename)
134157
case _:

pyproject.toml

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
[tool.poetry]
2+
name = "industryDocumentsWrapper"
3+
version = "0.111"
4+
description = "A simple python wrapper for the UCSF Industry Documents API."
5+
authors = ["Rolando Rodriguez <[email protected]>"]
6+
maintainers = ["Rolando Rodriguez <[email protected]>"]
7+
license = "Apache-2.0"
8+
readme = "README.md"
9+
packages = [{include = "industryDocumentsWrapper"}]
10+
repository = "https://github.com/UNC-Libraries/UCSF-Industry-Docs-API-Python-Wrapper"
11+
keywords = ["UCSF", "Industry Documents", "API", "JUUL"]
12+
13+
[tool.poetry.dependencies]
14+
python = "^3.12"
15+
polars = "^1.14.0"
16+
requests = "^2.32.3"
17+
18+
[tool.poetry.group.test.dependencies]
19+
pytest="^8.2.0"
20+
21+
[build-system]
22+
requires = ["poetry-core"]
23+
build-backend = "poetry.core.masonry.api"

tests/test_ucsf_api.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from unittest import mock
33
from industryDocumentsWrapper import IndustryDocsSearch
44

5+
## TO DO:
6+
## 1. Fix the mock response and set tests to use mock response
57
# Mock the requests.get() response
68
@pytest.fixture
79
def mock_json_response():
@@ -57,10 +59,10 @@ def reset_results(indDocSearch):
5759
# Tests the IndustryDocsSearch methods
5860

5961
def test_create_query_with_q(indDocSearch):
60-
assert indDocSearch._create_query(q='collection:test AND industry:tobacco', wt='json', cursorMark='*', sort='id%20asc') == 'https://metadata.idl.ucsf.edu/solr/ltdl3/query?q=(collection:test AND industry:tobacco)&wt=json&cursorMark=*&sort=id%20asc'
62+
assert indDocSearch._create_query(q='collection:"test" AND industry:"tobacco"', wt='json', cursorMark='*', sort='id%20asc') == 'https://metadata.idl.ucsf.edu/solr/ltdl3/query?q=(collection:"test" AND industry:"tobacco")&wt=json&cursorMark=*&sort=id%20asc'
6163

6264
def test_create_query_without_q(indDocSearch):
63-
assert indDocSearch._create_query(q=False, collection='test', industry='tobacco', wt='json', cursorMark='*', sort='id%20asc') == 'https://metadata.idl.ucsf.edu/solr/ltdl3/query?q=(collection:test AND industry:tobacco)&wt=json&cursorMark=*&sort=id%20asc'
65+
assert indDocSearch._create_query(q=False, collection='test', industry='tobacco', wt='json', cursorMark='*', sort='id%20asc') == 'https://metadata.idl.ucsf.edu/solr/ltdl3/query?q=(collection:"test" AND industry:"tobacco")&wt=json&cursorMark=*&sort=id%20asc'
6466

6567
def test_update_cursormark(indDocSearch):
6668
query = 'https://metadata.idl.ucsf.edu/solr/ltdl3/query?q=(collection:test)&wt=json&cursorMark=*&sort=id%20asc'
@@ -101,14 +103,19 @@ def test_query_with_q_500(indDocSearch):
101103
assert indDocSearch.results[0]['url'] == 'https://www.industrydocuments.ucsf.edu/tobacco/docs/#id=ffbb0284'
102104

103105
def test_query_with_no_q_50(indDocSearch):
104-
indDocSearch.query(industry='tobacco', collection='JUUL labs Collection', case='State of North Carolina', doc_type='email', n=50)
106+
indDocSearch.query(industry='tobacco', collection='JUUL labs Collection', case='State of North Carolina', type='email', n=50)
105107
assert len(indDocSearch.results) == 50
106108
assert len(set([x['id'] for x in indDocSearch.results])) == 50
107109

108110
def test_query_with_no_q_1000(indDocSearch):
109-
indDocSearch.query(industry='tobacco', collection='JUUL labs Collection', case='State of North Carolina', doc_type='email', n=1000)
111+
indDocSearch.query(industry='tobacco', collection='JUUL labs Collection', case='State of North Carolina', type='email', n=1000)
110112
assert len(indDocSearch.results) == 1000
111113
assert len(set([x['id'] for x in indDocSearch.results])) == 1000
114+
115+
def test_query_with_no_q_50000(indDocSearch):
116+
indDocSearch.query(industry='tobacco', collection='JUUL labs Collection', case='State of North Carolina', type='email', n=50000)
117+
assert len(indDocSearch.results) == 50000
118+
assert len(set([x['id'] for x in indDocSearch.results])) == 50000
112119

113120
def test_save_parquet(indDocSearch, mock_results, tmp_path):
114121
indDocSearch.results = mock_results
@@ -120,12 +127,12 @@ def test_save_parquet(indDocSearch, mock_results, tmp_path):
120127
assert d.exists()
121128
assert d.stat().st_size > 0
122129

123-
# def test_save_csv(indDocSearch, mock_results, tmp_path):
124-
# indDocSearch.results = mock_results
125-
# d = tmp_path / 'test.csv'
126-
# indDocSearch.save(d, format='csv')
127-
# assert d.exists()
128-
# assert d.stat().st_size > 0
130+
def test_save_csv(indDocSearch, mock_results, tmp_path):
131+
indDocSearch.results = mock_results
132+
d = tmp_path / 'test.csv'
133+
indDocSearch.save(d, format='csv')
134+
assert d.exists()
135+
assert d.stat().st_size > 0
129136

130137
def test_save_json(indDocSearch, mock_results, tmp_path):
131138
indDocSearch.results = mock_results

0 commit comments

Comments
 (0)