Skip to content

Commit e23f909

Browse files
committed
cp to tests
1 parent 7c3c2b7 commit e23f909

1 file changed

Lines changed: 390 additions & 0 deletions

File tree

tests/test_fetch.py

Lines changed: 390 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,390 @@
1+
from nose.tools import (
2+
assert_true,
3+
raises,
4+
assert_equal,
5+
assert_dict_equal,
6+
assert_list_equal,
7+
assert_set_equal,
8+
)
9+
from operator import itemgetter
10+
import itertools
11+
import numpy as np
12+
import decimal
13+
import pandas
14+
import warnings
15+
from . import schema
16+
from .schema import Parent, Stimulus
17+
import datajoint as dj
18+
import os
19+
import logging
20+
import io
21+
22+
logger = logging.getLogger("datajoint")
23+
24+
25+
class TestFetch:
    """Tests for `fetch`/`fetch1` behavior: attribute selection, ordering,
    limit/offset, output formats (array/dict/frame), KEY handling, and the
    query cache."""

    @classmethod
    def setup_class(cls):
        # Shared table handles for all tests; the schema module provides
        # pre-populated test tables.
        cls.subject = schema.Subject()
        cls.lang = schema.Language()

    def test_getattribute(self):
        """Testing Fetch.__call__ with attributes"""
        list1 = sorted(
            self.subject.proj().fetch(as_dict=True), key=itemgetter("subject_id")
        )
        list2 = sorted(self.subject.fetch(dj.key), key=itemgetter("subject_id"))
        for l1, l2 in zip(list1, list2):
            assert_dict_equal(l1, l2, "Primary key is not returned correctly")

        tmp = self.subject.fetch(order_by="subject_id")

        subject_notes, key, real_id = self.subject.fetch(
            "subject_notes", dj.key, "real_id"
        )

        np.testing.assert_array_equal(
            sorted(subject_notes), sorted(tmp["subject_notes"])
        )
        np.testing.assert_array_equal(sorted(real_id), sorted(tmp["real_id"]))
        list1 = sorted(key, key=itemgetter("subject_id"))
        for l1, l2 in zip(list1, list2):
            assert_dict_equal(l1, l2, "Primary key is not returned correctly")

    def test_getattribute_for_fetch1(self):
        """Testing Fetch1.__call__ with attributes"""
        assert_true((self.subject & "subject_id=10").fetch1("subject_id") == 10)
        assert_equal(
            (self.subject & "subject_id=10").fetch1("subject_id", "species"),
            (10, "monkey"),
        )

    def test_order_by(self):
        """Tests order_by sorting order"""
        languages = schema.Language.contents

        for ord_name, ord_lang in itertools.product(*2 * [["ASC", "DESC"]]):
            cur = self.lang.fetch(order_by=("name " + ord_name, "language " + ord_lang))
            languages.sort(key=itemgetter(1), reverse=ord_lang == "DESC")
            languages.sort(key=itemgetter(0), reverse=ord_name == "DESC")
            for c, l in zip(cur, languages):
                # BUG FIX: np.all(<generator>) treats the generator object as a
                # single truthy scalar, so the original assertion could never
                # fail.  Use the builtin all() to actually evaluate the
                # element-wise comparisons.
                assert_true(
                    all(cc == ll for cc, ll in zip(c, l)),
                    "Sorting order is different",
                )

    def test_order_by_default(self):
        """Tests order_by sorting order with defaults"""
        languages = schema.Language.contents
        cur = self.lang.fetch(order_by=("language", "name DESC"))
        # Stable sort: secondary key (name DESC) first, then primary (language ASC).
        languages.sort(key=itemgetter(0), reverse=True)
        languages.sort(key=itemgetter(1), reverse=False)
        for c, l in zip(cur, languages):
            assert_true(
                np.all([cc == ll for cc, ll in zip(c, l)]), "Sorting order is different"
            )

    def test_limit(self):
        """Test the limit kwarg"""
        limit = 4
        cur = self.lang.fetch(limit=limit)
        assert_equal(len(cur), limit, "Length is not correct")

    def test_order_by_limit(self):
        """Test the combination of order by and limit kwargs"""
        languages = schema.Language.contents

        cur = self.lang.fetch(limit=4, order_by=["language", "name DESC"])
        languages.sort(key=itemgetter(0), reverse=True)
        languages.sort(key=itemgetter(1), reverse=False)
        assert_equal(len(cur), 4, "Length is not correct")
        for c, l in list(zip(cur, languages))[:4]:
            assert_true(
                np.all([cc == ll for cc, ll in zip(c, l)]), "Sorting order is different"
            )

    @staticmethod
    def test_head_tail():
        """Test head()/tail() previews in both 'frame' and 'array' formats."""
        query = schema.User * schema.Language
        n = 5
        frame = query.head(n, format="frame")
        assert_true(isinstance(frame, pandas.DataFrame))
        array = query.head(n, format="array")
        assert_equal(array.size, n)
        assert_equal(len(frame), n)
        assert_list_equal(query.primary_key, frame.index.names)

        n = 4
        frame = query.tail(n, format="frame")
        array = query.tail(n, format="array")
        assert_equal(array.size, n)
        assert_equal(len(frame), n)
        assert_list_equal(query.primary_key, frame.index.names)

    def test_limit_offset(self):
        """Test the limit and offset kwargs together"""
        languages = schema.Language.contents

        cur = self.lang.fetch(offset=2, limit=4, order_by=["language", "name DESC"])
        languages.sort(key=itemgetter(0), reverse=True)
        languages.sort(key=itemgetter(1), reverse=False)
        assert_equal(len(cur), 4, "Length is not correct")
        for c, l in list(zip(cur, languages[2:6])):
            assert_true(
                np.all([cc == ll for cc, ll in zip(c, l)]), "Sorting order is different"
            )

    def test_iter(self):
        """Test iterator"""
        languages = schema.Language.contents
        cur = self.lang.fetch(order_by=["language", "name DESC"])
        languages.sort(key=itemgetter(0), reverse=True)
        languages.sort(key=itemgetter(1), reverse=False)
        for (name, lang), (tname, tlang) in list(zip(cur, languages)):
            assert_true(name == tname and lang == tlang, "Values are not the same")
        # now as dict
        cur = self.lang.fetch(as_dict=True, order_by=("language", "name DESC"))
        for row, (tname, tlang) in list(zip(cur, languages)):
            assert_true(
                row["name"] == tname and row["language"] == tlang,
                "Values are not the same",
            )

    def test_keys(self):
        """test key fetch"""
        languages = schema.Language.contents
        languages.sort(key=itemgetter(0), reverse=True)
        languages.sort(key=itemgetter(1), reverse=False)

        cur = self.lang.fetch("name", "language", order_by=("language", "name DESC"))
        cur2 = list(self.lang.fetch("KEY", order_by=["language", "name DESC"]))

        # zip(*cur) transposes the per-attribute arrays into per-row tuples.
        for c, c2 in zip(zip(*cur), cur2):
            assert_true(c == tuple(c2.values()), "Values are not the same")

    def test_attributes_as_dict(self):
        """Fetching a subset of attributes with as_dict=True — issue #595."""
        attrs = ("species", "date_of_birth")
        result = self.subject.fetch(*attrs, as_dict=True)
        assert_true(bool(result) and len(result) == len(self.subject))
        assert_set_equal(set(result[0]), set(attrs))

    def test_fetch1_step1(self):
        """fetch1() of a single restricted row matches the known contents."""
        key = {"name": "Edgar", "language": "Japanese"}
        true = schema.Language.contents[-1]
        dat = (self.lang & key).fetch1()
        for k, (ke, c) in zip(true, dat.items()):
            assert_true(
                k == c == (self.lang & key).fetch1(ke), "Values are not the same"
            )

    @raises(dj.DataJointError)
    def test_misspelled_attribute(self):
        """Restricting by a nonexistent attribute must raise DataJointError."""
        # The fetch result is irrelevant; only the raised error matters,
        # so don't bind it to an unused local.
        (schema.Language & 'lang = "ENGLISH"').fetch()

    def test_repr(self):
        """Test string representation of fetch, returning table preview"""
        # Avoid shadowing the builtin repr() with the local name.
        preview = self.subject.fetch.__repr__()
        n = len(preview.strip().split("\n"))
        limit = dj.config["display.limit"]
        # 3 lines are used for headers (2) and summary statement (1)
        assert_true(n - 3 <= limit)

    @raises(dj.DataJointError)
    def test_fetch_none(self):
        """Test preparing attributes for getitem"""
        self.lang.fetch(None)

    def test_asdict(self):
        """Test returns as dictionaries"""
        d = self.lang.fetch(as_dict=True)
        for dd in d:
            assert_true(isinstance(dd, dict))

    def test_offset(self):
        """Tests offset"""
        cur = self.lang.fetch(limit=4, offset=1, order_by=["language", "name DESC"])

        languages = self.lang.contents
        languages.sort(key=itemgetter(0), reverse=True)
        languages.sort(key=itemgetter(1), reverse=False)
        assert_equal(len(cur), 4, "Length is not correct")
        for c, l in list(zip(cur, languages[1:]))[:4]:
            assert_true(
                np.all([cc == ll for cc, ll in zip(c, l)]), "Sorting order is different"
            )

    def test_limit_warning(self):
        """Tests whether warning is raised if offset is used without limit."""
        log_capture = io.StringIO()
        stream_handler = logging.StreamHandler(log_capture)
        log_format = logging.Formatter(
            "[%(asctime)s][%(funcName)s][%(levelname)s]: %(message)s"
        )
        stream_handler.setFormatter(log_format)
        stream_handler.set_name("test_limit_warning")
        logger.addHandler(stream_handler)
        try:
            self.lang.fetch(offset=1)
            log_contents = log_capture.getvalue()
            log_capture.close()
        finally:
            # Clean up the handler even if fetch raises, so a failure here
            # cannot leak a handler into other tests.
            for handler in logger.handlers:
                if handler.name == "test_limit_warning":
                    logger.removeHandler(handler)
        assert "[WARNING]: Offset set, but no limit." in log_contents

    def test_len(self):
        """Tests __len__"""
        assert_equal(
            len(self.lang.fetch()), len(self.lang), "__len__ is not behaving properly"
        )

    @raises(dj.DataJointError)
    def test_fetch1_step2(self):
        """Tests whether fetch1 raises error"""
        self.lang.fetch1()

    @raises(dj.DataJointError)
    def test_fetch1_step3(self):
        """Tests whether fetch1 raises error"""
        self.lang.fetch1("name")

    def test_decimal(self):
        """Tests that decimal fields are correctly fetched and used in restrictions, see issue #334"""
        rel = schema.DecimalPrimaryKey()
        rel.insert1([decimal.Decimal("3.1415926")])
        keys = rel.fetch()
        assert_true(len(rel & keys[0]) == 1)
        keys = rel.fetch(dj.key)
        # NOTE(review): indexing keys[1] assumes the table holds at least two
        # rows after the insert above — presumably from prior contents; verify.
        assert_true(len(rel & keys[1]) == 1)

    def test_nullable_numbers(self):
        """test mixture of values and nulls in numeric attributes"""
        table = schema.NullableNumbers()
        table.insert(
            (
                (
                    k,
                    np.random.randn(),
                    np.random.randint(-1000, 1000),
                    np.random.randn(),
                )
                for k in range(10)
            )
        )
        table.insert1((100, None, None, None))
        f, d, i = table.fetch("fvalue", "dvalue", "ivalue")
        # Integer nulls come back as None; float nulls come back as NaN.
        assert_true(None in i)
        assert_true(any(np.isnan(d)))
        assert_true(any(np.isnan(f)))

    def test_fetch_format(self):
        """test fetch_format='frame'"""
        with dj.config(fetch_format="frame"):
            # test if lists are both dicts
            list1 = sorted(
                self.subject.proj().fetch(as_dict=True), key=itemgetter("subject_id")
            )
            list2 = sorted(self.subject.fetch(dj.key), key=itemgetter("subject_id"))
            for l1, l2 in zip(list1, list2):
                assert_dict_equal(l1, l2, "Primary key is not returned correctly")

            # tests if pandas dataframe
            tmp = self.subject.fetch(order_by="subject_id")
            assert_true(isinstance(tmp, pandas.DataFrame))
            tmp = tmp.to_records()

            subject_notes, key, real_id = self.subject.fetch(
                "subject_notes", dj.key, "real_id"
            )

            np.testing.assert_array_equal(
                sorted(subject_notes), sorted(tmp["subject_notes"])
            )
            np.testing.assert_array_equal(sorted(real_id), sorted(tmp["real_id"]))
            list1 = sorted(key, key=itemgetter("subject_id"))
            for l1, l2 in zip(list1, list2):
                assert_dict_equal(l1, l2, "Primary key is not returned correctly")

    def test_key_fetch1(self):
        """test KEY fetch1 - issue #976"""
        with dj.config(fetch_format="array"):
            k1 = (self.subject & "subject_id=10").fetch1("KEY")
        with dj.config(fetch_format="frame"):
            k2 = (self.subject & "subject_id=10").fetch1("KEY")
        assert_equal(k1, k2)

    def test_same_secondary_attribute(self):
        """Join with a projected parent must not duplicate secondary attrs."""
        children = (schema.Child * schema.Parent().proj()).fetch()["name"]
        assert len(children) == 1
        assert children[0] == "Dan"

    def test_query_caching(self):
        """End-to-end test of the connection-level query cache."""
        # initialize cache directory
        os.mkdir(os.path.expanduser("~/dj_query_cache"))

        with dj.config(query_cache=os.path.expanduser("~/dj_query_cache")):
            conn = schema.TTest3.connection
            # insert sample data and load cache
            schema.TTest3.insert([dict(key=100 + i, value=200 + i) for i in range(2)])
            conn.set_query_cache(query_cache="main")
            cached_res = schema.TTest3().fetch()
            # attempt to insert while caching enabled; inserts must be
            # rejected so the cache cannot go stale silently
            try:
                schema.TTest3.insert(
                    [dict(key=200 + i, value=400 + i) for i in range(2)]
                )
                assert False, "Insert allowed while query caching enabled"
            except dj.DataJointError:
                conn.set_query_cache()
            # insert new data
            schema.TTest3.insert([dict(key=600 + i, value=800 + i) for i in range(2)])
            # re-enable cache to access old results
            conn.set_query_cache(query_cache="main")
            previous_cache = schema.TTest3().fetch()
            # verify properly cached and how to refresh results
            assert all([c == p for c, p in zip(cached_res, previous_cache)])
            conn.set_query_cache()
            uncached_res = schema.TTest3().fetch()
            assert len(uncached_res) > len(cached_res)
            # purge query cache
            conn.purge_query_cache()

        # reset cache directory state (will fail if purge was unsuccessful)
        os.rmdir(os.path.expanduser("~/dj_query_cache"))

    def test_fetch_group_by(self):
        """KEY fetch with order_by on a non-key attribute — issue #914.

        https://github.com/datajoint/datajoint-python/issues/914
        """
        assert Parent().fetch("KEY", order_by="name") == [{"parent_id": 1}]

    def test_dj_u_distinct(self):
        """dj.U must still return distinct rows if DISTINCT is removed from
        the generated SELECT statement."""
        # Contents to be inserted
        contents = [(1, 2, 3), (2, 2, 3), (3, 3, 2), (4, 5, 5)]
        Stimulus.insert(contents)

        # Query the whole table
        test_query = Stimulus()

        # Use dj.U to create a list of unique contrast and brightness combinations
        result = dj.U("contrast", "brightness") & test_query
        expected_result = [
            {"contrast": 2, "brightness": 3},
            {"contrast": 3, "brightness": 2},
            {"contrast": 5, "brightness": 5},
        ]

        fetched_result = result.fetch(as_dict=True, order_by=("contrast", "brightness"))
        Stimulus.delete_quick()
        assert fetched_result == expected_result

    def test_backslash(self):
        """Round-trip a value containing a literal backslash — issue #999.

        https://github.com/datajoint/datajoint-python/issues/999
        """
        # FIX: "She\Hulk" relied on the invalid escape sequence \H
        # (DeprecationWarning today, SyntaxError in future CPython);
        # "She\\Hulk" is the identical runtime string, spelled legally.
        expected = "She\\Hulk"
        Parent.insert([(2, expected)])
        q = Parent & dict(name=expected)
        assert q.fetch1("name") == expected
        q.delete()

0 commit comments

Comments
 (0)