|
| 1 | +from nose.tools import ( |
| 2 | + assert_true, |
| 3 | + raises, |
| 4 | + assert_equal, |
| 5 | + assert_dict_equal, |
| 6 | + assert_list_equal, |
| 7 | + assert_set_equal, |
| 8 | +) |
| 9 | +from operator import itemgetter |
| 10 | +import itertools |
| 11 | +import numpy as np |
| 12 | +import decimal |
| 13 | +import pandas |
| 14 | +import warnings |
| 15 | +from . import schema |
| 16 | +from .schema import Parent, Stimulus |
| 17 | +import datajoint as dj |
| 18 | +import os |
| 19 | +import logging |
| 20 | +import io |
| 21 | + |
# Module-level handle on the "datajoint" package logger; test_limit_warning
# attaches a temporary StreamHandler to it to capture warning output.
logger = logging.getLogger("datajoint")
| 23 | + |
| 24 | + |
class TestFetch:
    """Tests of fetch()/fetch1() behavior: attribute access, ordering,
    limit/offset, output formats, query caching, and assorted edge cases.

    Requires a live database populated by the companion ``schema`` module.
    """

    @classmethod
    def setup_class(cls):
        """Create shared table objects once for all tests in this class."""
        cls.subject = schema.Subject()
        cls.lang = schema.Language()

    def test_getattribute(self):
        """Testing Fetch.__call__ with attributes"""
        list1 = sorted(
            self.subject.proj().fetch(as_dict=True), key=itemgetter("subject_id")
        )
        list2 = sorted(self.subject.fetch(dj.key), key=itemgetter("subject_id"))
        for l1, l2 in zip(list1, list2):
            assert_dict_equal(l1, l2, "Primary key is not returned correctly")

        tmp = self.subject.fetch(order_by="subject_id")

        subject_notes, key, real_id = self.subject.fetch(
            "subject_notes", dj.key, "real_id"
        )

        np.testing.assert_array_equal(
            sorted(subject_notes), sorted(tmp["subject_notes"])
        )
        np.testing.assert_array_equal(sorted(real_id), sorted(tmp["real_id"]))
        list1 = sorted(key, key=itemgetter("subject_id"))
        for l1, l2 in zip(list1, list2):
            assert_dict_equal(l1, l2, "Primary key is not returned correctly")

    def test_getattribute_for_fetch1(self):
        """Testing Fetch1.__call__ with attributes"""
        assert_true((self.subject & "subject_id=10").fetch1("subject_id") == 10)
        assert_equal(
            (self.subject & "subject_id=10").fetch1("subject_id", "species"),
            (10, "monkey"),
        )

    def test_order_by(self):
        """Tests order_by sorting order"""
        languages = schema.Language.contents

        for ord_name, ord_lang in itertools.product(*2 * [["ASC", "DESC"]]):
            cur = self.lang.fetch(order_by=("name " + ord_name, "language " + ord_lang))
            languages.sort(key=itemgetter(1), reverse=ord_lang == "DESC")
            languages.sort(key=itemgetter(0), reverse=ord_name == "DESC")
            for c, l in zip(cur, languages):
                # BUGFIX: use a list, not a bare generator. np.all() of a
                # generator object is always truthy (it is treated as a 0-d
                # object array), which made this assertion vacuous. The list
                # form matches the correct usage in test_order_by_default.
                assert_true(
                    np.all([cc == ll for cc, ll in zip(c, l)]),
                    "Sorting order is different",
                )

    def test_order_by_default(self):
        """Tests order_by sorting order with defaults"""
        languages = schema.Language.contents
        cur = self.lang.fetch(order_by=("language", "name DESC"))
        # sort by secondary key first, then a stable sort by the primary key
        languages.sort(key=itemgetter(0), reverse=True)
        languages.sort(key=itemgetter(1), reverse=False)
        for c, l in zip(cur, languages):
            assert_true(
                np.all([cc == ll for cc, ll in zip(c, l)]), "Sorting order is different"
            )

    def test_limit(self):
        """Test the limit kwarg"""
        limit = 4
        cur = self.lang.fetch(limit=limit)
        assert_equal(len(cur), limit, "Length is not correct")

    def test_order_by_limit(self):
        """Test the combination of order by and limit kwargs"""
        languages = schema.Language.contents

        cur = self.lang.fetch(limit=4, order_by=["language", "name DESC"])
        languages.sort(key=itemgetter(0), reverse=True)
        languages.sort(key=itemgetter(1), reverse=False)
        assert_equal(len(cur), 4, "Length is not correct")
        for c, l in list(zip(cur, languages))[:4]:
            assert_true(
                np.all([cc == ll for cc, ll in zip(c, l)]), "Sorting order is different"
            )

    @staticmethod
    def test_head_tail():
        """Test head()/tail() in both frame and array formats."""
        query = schema.User * schema.Language
        n = 5
        frame = query.head(n, format="frame")
        assert_true(isinstance(frame, pandas.DataFrame))
        array = query.head(n, format="array")
        assert_equal(array.size, n)
        assert_equal(len(frame), n)
        # the DataFrame must be indexed by the query's primary key
        assert_list_equal(query.primary_key, frame.index.names)

        n = 4
        frame = query.tail(n, format="frame")
        array = query.tail(n, format="array")
        assert_equal(array.size, n)
        assert_equal(len(frame), n)
        assert_list_equal(query.primary_key, frame.index.names)

    def test_limit_offset(self):
        """Test the limit and offset kwargs together"""
        languages = schema.Language.contents

        cur = self.lang.fetch(offset=2, limit=4, order_by=["language", "name DESC"])
        languages.sort(key=itemgetter(0), reverse=True)
        languages.sort(key=itemgetter(1), reverse=False)
        assert_equal(len(cur), 4, "Length is not correct")
        # offset=2, limit=4 should correspond to sorted rows 2..5
        for c, l in list(zip(cur, languages[2:6])):
            assert_true(
                np.all([cc == ll for cc, ll in zip(c, l)]), "Sorting order is different"
            )

    def test_iter(self):
        """Test iterator"""
        languages = schema.Language.contents
        cur = self.lang.fetch(order_by=["language", "name DESC"])
        languages.sort(key=itemgetter(0), reverse=True)
        languages.sort(key=itemgetter(1), reverse=False)
        for (name, lang), (tname, tlang) in list(zip(cur, languages)):
            assert_true(name == tname and lang == tlang, "Values are not the same")
        # now as dict
        cur = self.lang.fetch(as_dict=True, order_by=("language", "name DESC"))
        for row, (tname, tlang) in list(zip(cur, languages)):
            assert_true(
                row["name"] == tname and row["language"] == tlang,
                "Values are not the same",
            )

    def test_keys(self):
        """test key fetch"""
        languages = schema.Language.contents
        languages.sort(key=itemgetter(0), reverse=True)
        languages.sort(key=itemgetter(1), reverse=False)

        cur = self.lang.fetch("name", "language", order_by=("language", "name DESC"))
        cur2 = list(self.lang.fetch("KEY", order_by=["language", "name DESC"]))

        # zip(*cur) transposes per-attribute arrays into per-row tuples
        for c, c2 in zip(zip(*cur), cur2):
            assert_true(c == tuple(c2.values()), "Values are not the same")

    def test_attributes_as_dict(self):  # issue #595
        """Fetching named attributes with as_dict=True restricts dict keys."""
        attrs = ("species", "date_of_birth")
        result = self.subject.fetch(*attrs, as_dict=True)
        assert_true(bool(result) and len(result) == len(self.subject))
        assert_set_equal(set(result[0]), set(attrs))

    def test_fetch1_step1(self):
        """fetch1() of a single row agrees with per-attribute fetch1 calls."""
        key = {"name": "Edgar", "language": "Japanese"}
        true = schema.Language.contents[-1]
        dat = (self.lang & key).fetch1()
        for k, (ke, c) in zip(true, dat.items()):
            assert_true(
                k == c == (self.lang & key).fetch1(ke), "Values are not the same"
            )

    @raises(dj.DataJointError)
    def test_misspelled_attribute(self):
        """Restricting by a nonexistent attribute must raise DataJointError."""
        # no need to bind the result; the fetch() call itself must raise
        (schema.Language & 'lang = "ENGLISH"').fetch()

    def test_repr(self):
        """Test string representation of fetch, returning table preview"""
        repr = self.subject.fetch.__repr__()
        n = len(repr.strip().split("\n"))
        limit = dj.config["display.limit"]
        # 3 lines are used for headers (2) and summary statement (1)
        assert_true(n - 3 <= limit)

    @raises(dj.DataJointError)
    def test_fetch_none(self):
        """Test preparing attributes for getitem"""
        self.lang.fetch(None)

    def test_asdict(self):
        """Test returns as dictionaries"""
        d = self.lang.fetch(as_dict=True)
        for dd in d:
            assert_true(isinstance(dd, dict))

    def test_offset(self):
        """Tests offset"""
        cur = self.lang.fetch(limit=4, offset=1, order_by=["language", "name DESC"])

        languages = self.lang.contents
        languages.sort(key=itemgetter(0), reverse=True)
        languages.sort(key=itemgetter(1), reverse=False)
        assert_equal(len(cur), 4, "Length is not correct")
        for c, l in list(zip(cur, languages[1:]))[:4]:
            assert_true(
                np.all([cc == ll for cc, ll in zip(c, l)]), "Sorting order is different"
            )

    def test_limit_warning(self):
        """Tests whether warning is raised if offset is used without limit."""
        log_capture = io.StringIO()
        stream_handler = logging.StreamHandler(log_capture)
        log_format = logging.Formatter(
            "[%(asctime)s][%(funcName)s][%(levelname)s]: %(message)s"
        )
        stream_handler.setFormatter(log_format)
        # name the handler so it can be found and removed during cleanup
        stream_handler.set_name("test_limit_warning")
        logger.addHandler(stream_handler)
        self.lang.fetch(offset=1)

        log_contents = log_capture.getvalue()
        log_capture.close()

        for handler in logger.handlers:  # Clean up handler
            if handler.name == "test_limit_warning":
                logger.removeHandler(handler)
        assert "[WARNING]: Offset set, but no limit." in log_contents

    def test_len(self):
        """Tests __len__"""
        assert_equal(
            len(self.lang.fetch()), len(self.lang), "__len__ is not behaving properly"
        )

    @raises(dj.DataJointError)
    def test_fetch1_step2(self):
        """Tests whether fetch1 raises error"""
        self.lang.fetch1()

    @raises(dj.DataJointError)
    def test_fetch1_step3(self):
        """Tests whether fetch1 raises error"""
        self.lang.fetch1("name")

    def test_decimal(self):
        """Tests that decimal fields are correctly fetched and used in restrictions, see issue #334"""
        rel = schema.DecimalPrimaryKey()
        rel.insert1([decimal.Decimal("3.1415926")])
        keys = rel.fetch()
        assert_true(len(rel & keys[0]) == 1)
        keys = rel.fetch(dj.key)
        assert_true(len(rel & keys[1]) == 1)

    def test_nullable_numbers(self):
        """test mixture of values and nulls in numeric attributes"""
        table = schema.NullableNumbers()
        table.insert(
            (
                (
                    k,
                    np.random.randn(),
                    np.random.randint(-1000, 1000),
                    np.random.randn(),
                )
                for k in range(10)
            )
        )
        table.insert1((100, None, None, None))
        f, d, i = table.fetch("fvalue", "dvalue", "ivalue")
        # nullable ints come back as objects with None; nullable floats as NaN
        assert_true(None in i)
        assert_true(any(np.isnan(d)))
        assert_true(any(np.isnan(f)))

    def test_fetch_format(self):
        """test fetch_format='frame'"""
        with dj.config(fetch_format="frame"):
            # test if lists are both dicts
            list1 = sorted(
                self.subject.proj().fetch(as_dict=True), key=itemgetter("subject_id")
            )
            list2 = sorted(self.subject.fetch(dj.key), key=itemgetter("subject_id"))
            for l1, l2 in zip(list1, list2):
                assert_dict_equal(l1, l2, "Primary key is not returned correctly")

            # tests if pandas dataframe
            tmp = self.subject.fetch(order_by="subject_id")
            assert_true(isinstance(tmp, pandas.DataFrame))
            tmp = tmp.to_records()

            subject_notes, key, real_id = self.subject.fetch(
                "subject_notes", dj.key, "real_id"
            )

            np.testing.assert_array_equal(
                sorted(subject_notes), sorted(tmp["subject_notes"])
            )
            np.testing.assert_array_equal(sorted(real_id), sorted(tmp["real_id"]))
            list1 = sorted(key, key=itemgetter("subject_id"))
            for l1, l2 in zip(list1, list2):
                assert_dict_equal(l1, l2, "Primary key is not returned correctly")

    def test_key_fetch1(self):
        """test KEY fetch1 - issue #976"""
        with dj.config(fetch_format="array"):
            k1 = (self.subject & "subject_id=10").fetch1("KEY")
        with dj.config(fetch_format="frame"):
            k2 = (self.subject & "subject_id=10").fetch1("KEY")
        assert_equal(k1, k2)

    def test_same_secondary_attribute(self):
        """Join with a projection must not clobber same-named secondary attrs."""
        children = (schema.Child * schema.Parent().proj()).fetch()["name"]
        assert len(children) == 1
        assert children[0] == "Dan"

    def test_query_caching(self):
        """Inserts are blocked while the query cache is active; cached results
        persist until the cache is purged."""
        # initialize cache directory
        os.mkdir(os.path.expanduser("~/dj_query_cache"))

        with dj.config(query_cache=os.path.expanduser("~/dj_query_cache")):
            conn = schema.TTest3.connection
            # insert sample data and load cache
            schema.TTest3.insert([dict(key=100 + i, value=200 + i) for i in range(2)])
            conn.set_query_cache(query_cache="main")
            cached_res = schema.TTest3().fetch()
            # attempt to insert while caching enabled
            try:
                schema.TTest3.insert(
                    [dict(key=200 + i, value=400 + i) for i in range(2)]
                )
                assert False, "Insert allowed while query caching enabled"
            except dj.DataJointError:
                conn.set_query_cache()
            # insert new data
            schema.TTest3.insert([dict(key=600 + i, value=800 + i) for i in range(2)])
            # re-enable cache to access old results
            conn.set_query_cache(query_cache="main")
            previous_cache = schema.TTest3().fetch()
            # verify properly cached and how to refresh results
            assert all([c == p for c, p in zip(cached_res, previous_cache)])
            conn.set_query_cache()
            uncached_res = schema.TTest3().fetch()
            assert len(uncached_res) > len(cached_res)
            # purge query cache
            conn.purge_query_cache()

        # reset cache directory state (will fail if purge was unsuccessful)
        os.rmdir(os.path.expanduser("~/dj_query_cache"))

    def test_fetch_group_by(self):
        """Fetching KEY with order_by must not break on grouped queries."""
        # https://github.com/datajoint/datajoint-python/issues/914

        assert Parent().fetch("KEY", order_by="name") == [{"parent_id": 1}]

    def test_dj_u_distinct(self):
        """dj.U must still deduplicate rows (relies on DISTINCT semantics)."""
        # Test developed to see if removing DISTINCT from the select statement
        # generation breaks the dj.U universal set implementation

        # Contents to be inserted
        contents = [(1, 2, 3), (2, 2, 3), (3, 3, 2), (4, 5, 5)]
        Stimulus.insert(contents)

        # Query the whole table
        test_query = Stimulus()

        # Use dj.U to create a list of unique contrast and brightness combinations
        result = dj.U("contrast", "brightness") & test_query
        expected_result = [
            {"contrast": 2, "brightness": 3},
            {"contrast": 3, "brightness": 2},
            {"contrast": 5, "brightness": 5},
        ]

        fetched_result = result.fetch(as_dict=True, order_by=("contrast", "brightness"))
        Stimulus.delete_quick()
        assert fetched_result == expected_result

    def test_backslash(self):
        """Round-trip a value containing a literal backslash through a
        restriction and fetch1."""
        # https://github.com/datajoint/datajoint-python/issues/999
        # BUGFIX: escape the backslash explicitly. "\H" is an invalid escape
        # sequence (SyntaxWarning since Python 3.12); the runtime value
        # "She\Hulk" is unchanged.
        expected = "She\\Hulk"
        Parent.insert([(2, expected)])
        q = Parent & dict(name=expected)
        assert q.fetch1("name") == expected
        q.delete()
0 commit comments