11import pytest
2+ from typing import List
23from operator import itemgetter
34import itertools
45import numpy as np
1213import io
1314
1415
16+ @pytest .fixture
17+ def lang ():
18+ yield schema .Language ()
19+
20+
21+ @pytest .fixture
22+ def languages (lang ) -> List :
23+ og_contents = lang .contents
24+ languages = og_contents .copy ()
25+ yield languages
26+ lang .contents = og_contents
27+
28+
29+ @pytest .fixture
30+ def subject ():
31+ yield schema .Subject ()
32+
33+
34+
1535class TestFetch :
16- def test_getattribute (self , schema_any ):
36+ def test_getattribute (self , schema_any , subject ):
1737 """Testing Fetch.__call__ with attributes"""
18- subject = schema .Subject ()
1938 list1 = sorted (
2039 subject .proj ().fetch (as_dict = True ), key = itemgetter ("subject_id" )
2140 )
@@ -37,19 +56,15 @@ def test_getattribute(self, schema_any):
3756 for l1 , l2 in zip (list1 , list2 ):
3857 assert l1 == l2 , "Primary key is not returned correctly"
3958
40- def test_getattribute_for_fetch1 (self , schema_any ):
59+ def test_getattribute_for_fetch1 (self , schema_any , subject ):
4160 """Testing Fetch1.__call__ with attributes"""
42- subject = schema .Subject ()
4361 assert (subject & "subject_id=10" ).fetch1 ("subject_id" ) == 10
4462 assert (
4563 (subject & "subject_id=10" ).fetch1 ("subject_id" , "species" ) ==
4664 (10 , "monkey" ))
4765
48- def test_order_by (self , schema_any ):
66+ def test_order_by (self , schema_any , lang , languages ):
4967 """Tests order_by sorting order"""
50- lang = schema .Language ()
51- languages = schema .Language .contents
52-
5368 for ord_name , ord_lang in itertools .product (* 2 * [["ASC" , "DESC" ]]):
5469 cur = lang .fetch (order_by = ("name " + ord_name , "language " + ord_lang ))
5570 languages .sort (key = itemgetter (1 ), reverse = ord_lang == "DESC" )
@@ -60,38 +75,31 @@ def test_order_by(self, schema_any):
6075 "Sorting order is different" ,
6176 )
6277
63- def test_order_by_default (self , schema_any ):
78+ def test_order_by_default (self , schema_any , lang , languages ):
6479 """Tests order_by sorting order with defaults"""
65- lang = schema .Language ()
66- languages = schema .Language .contents
6780 cur = lang .fetch (order_by = ("language" , "name DESC" ))
6881 languages .sort (key = itemgetter (0 ), reverse = True )
6982 languages .sort (key = itemgetter (1 ), reverse = False )
7083 for c , l in zip (cur , languages ):
7184 assert np .all ([cc == ll for cc , ll in zip (c , l )]), "Sorting order is different"
7285
73- def test_limit (self , schema_any ):
86+ def test_limit (self , schema_any , lang ):
7487 """Test the limit kwarg"""
75- lang = schema .Language ()
7688 limit = 4
7789 cur = lang .fetch (limit = limit )
7890 assert len (cur ) == limit , "Length is not correct"
7991
80- def test_order_by_limit (self , schema_any ):
92+ def test_order_by_limit (self , schema_any , lang , languages ):
8193 """Test the combination of order by and limit kwargs"""
82- lang = schema .Language ()
83- languages = schema .Language .contents
84-
8594 cur = lang .fetch (limit = 4 , order_by = ["language" , "name DESC" ])
8695 languages .sort (key = itemgetter (0 ), reverse = True )
8796 languages .sort (key = itemgetter (1 ), reverse = False )
8897 assert len (cur ) == 4 , "Length is not correct"
8998 for c , l in list (zip (cur , languages ))[:4 ]:
9099 assert np .all ([cc == ll for cc , ll in zip (c , l )]), "Sorting order is different"
91100
92- @staticmethod
93101 def test_head_tail (self , schema_any ):
94- query = schema_any .User * schema .Language
102+ query = schema .User * schema .Language
95103 n = 5
96104 frame = query .head (n , format = "frame" )
97105 assert isinstance (frame , pandas .DataFrame )
@@ -107,27 +115,22 @@ def test_head_tail(self, schema_any):
107115 assert len (frame ) == n
108116 assert query .primary_key == frame .index .names
109117
110- def test_limit_offset (self , schema_any ):
118+ def test_limit_offset (self , schema_any , lang , languages ):
111119 """Test the limit and offset kwargs together"""
112- lang = schema .Language ()
113- languages = schema .Language .contents
114-
115120 cur = lang .fetch (offset = 2 , limit = 4 , order_by = ["language" , "name DESC" ])
116121 languages .sort (key = itemgetter (0 ), reverse = True )
117122 languages .sort (key = itemgetter (1 ), reverse = False )
118123 assert len (cur ) == 4 , "Length is not correct"
119124 for c , l in list (zip (cur , languages [2 :6 ])):
120125 assert np .all ([cc == ll for cc , ll in zip (c , l )]), "Sorting order is different"
121126
122- def test_iter (self , schema_any ):
127+ def test_iter (self , schema_any , lang , languages ):
123128 """Test iterator"""
124- lang = schema .Language ()
125- languages = schema .Language .contents
126129 cur = lang .fetch (order_by = ["language" , "name DESC" ])
127130 languages .sort (key = itemgetter (0 ), reverse = True )
128131 languages .sort (key = itemgetter (1 ), reverse = False )
129- for (name , lang ), (tname , tlang ) in list (zip (cur , languages )):
130- assert name == tname and lang == tlang , "Values are not the same"
132+ for (name , lang_val ), (tname , tlang ) in list (zip (cur , languages )):
133+ assert name == tname and lang_val == tlang , "Values are not the same"
131134 # now as dict
132135 cur = lang .fetch (as_dict = True , order_by = ("language" , "name DESC" ))
133136 for row , (tname , tlang ) in list (zip (cur , languages )):
@@ -136,30 +139,38 @@ def test_iter(self, schema_any):
136139 "Values are not the same" ,
137140 )
138141
139- def test_keys (self , schema_any ):
142+ def test_keys (self , schema_any , lang , languages ):
140143 """test key fetch"""
141- lang = schema .Language ()
142- languages = schema .Language .contents
143144 languages .sort (key = itemgetter (0 ), reverse = True )
144145 languages .sort (key = itemgetter (1 ), reverse = False )
145146
147+ lang = schema .Language ()
146148 cur = lang .fetch ("name" , "language" , order_by = ("language" , "name DESC" ))
147149 cur2 = list (lang .fetch ("KEY" , order_by = ["language" , "name DESC" ]))
148150
149151 for c , c2 in zip (zip (* cur ), cur2 ):
150152 assert c == tuple (c2 .values ()), "Values are not the same"
151153
152- def test_attributes_as_dict (self , schema_any ): # issue #595
153- subject = schema .Subject ()
154+ def test_attributes_as_dict (self , schema_any , subject ):
155+ """
156+ Issue #595
157+ """
154158 attrs = ("species" , "date_of_birth" )
155159 result = subject .fetch (* attrs , as_dict = True )
156160 assert bool (result ) and len (result ) == len (subject )
157161 assert set (result [0 ]) == set (attrs )
158162
159- def test_fetch1_step1 (self , schema_any ):
160- lang = schema .Language ()
163+ def test_fetch1_step1 (self , schema_any , lang , languages ):
164+ assert lang .contents == languages == [
165+ ("Fabian" , "English" ),
166+ ("Edgar" , "English" ),
167+ ("Dimitri" , "English" ),
168+ ("Dimitri" , "Ukrainian" ),
169+ ("Fabian" , "German" ),
170+ ("Edgar" , "Japanese" ),
171+ ], "Unexpected contents in Language table"
161172 key = {"name" : "Edgar" , "language" : "Japanese" }
162- true = schema . Language . contents [- 1 ]
173+ true = languages [- 1 ]
163174 dat = (lang & key ).fetch1 ()
164175 for k , (ke , c ) in zip (true , dat .items ()):
165176 assert k == c == (lang & key ).fetch1 (ke ), "Values are not the same"
@@ -168,43 +179,37 @@ def test_misspelled_attribute(self, schema_any):
168179 with pytest .raises (dj .DataJointError ):
169180 f = (schema .Language & 'lang = "ENGLISH"' ).fetch ()
170181
171- def test_repr (self , schema_any ):
182+ def test_repr (self , schema_any , subject ):
172183 """Test string representation of fetch, returning table preview"""
173- subject = schema .Subject ()
174184 repr = subject .fetch .__repr__ ()
175185 n = len (repr .strip ().split ("\n " ))
176186 limit = dj .config ["display.limit" ]
177187 # 3 lines are used for headers (2) and summary statement (1)
178188 assert n - 3 <= limit
179189
180- def test_fetch_none (self , schema_any ):
190+ def test_fetch_none (self , schema_any , lang ):
181191 """Test preparing attributes for getitem"""
182- lang = schema .Language ()
183192 with pytest .raises (dj .DataJointError ):
184193 lang .fetch (None )
185194
186- def test_asdict (self , schema_any ):
195+ def test_asdict (self , schema_any , lang ):
187196 """Test returns as dictionaries"""
188- lang = schema .Language ()
189197 d = lang .fetch (as_dict = True )
190198 for dd in d :
191199 assert isinstance (dd , dict )
192200
193- def test_offset (self , schema_any ):
201+ def test_offset (self , schema_any , lang , languages ):
194202 """Tests offset"""
195- lang = schema .Language ()
196203 cur = lang .fetch (limit = 4 , offset = 1 , order_by = ["language" , "name DESC" ])
197204
198- languages = lang .contents
199205 languages .sort (key = itemgetter (0 ), reverse = True )
200206 languages .sort (key = itemgetter (1 ), reverse = False )
201207 assert len (cur ) == 4 , "Length is not correct"
202208 for c , l in list (zip (cur , languages [1 :]))[:4 ]:
203209 assert np .all ([cc == ll for cc , ll in zip (c , l )]), "Sorting order is different"
204210
205- def test_limit_warning (self , schema_any ):
211+ def test_limit_warning (self , schema_any , lang ):
206212 """Tests whether warning is raised if offset is used without limit."""
207- lang = schema .Language ()
208213 logger = logging .getLogger ("datajoint" )
209214 log_capture = io .StringIO ()
210215 stream_handler = logging .StreamHandler (log_capture )
@@ -224,21 +229,17 @@ def test_limit_warning(self, schema_any):
224229 logger .removeHandler (handler )
225230 assert "[WARNING]: Offset set, but no limit." in log_contents
226231
227- def test_len (self , schema_any ):
232+ def test_len (self , schema_any , lang ):
228233 """Tests __len__"""
229- lang = schema .Language ()
230- assert (
231- len (lang .fetch ()) == len (lang )), "__len__ is not behaving properly"
234+ assert len (lang .fetch ()) == len (lang ), "__len__ is not behaving properly"
232235
233- def test_fetch1_step2 (self , schema_any ):
236+ def test_fetch1_step2 (self , schema_any , lang ):
234237 """Tests whether fetch1 raises error"""
235- lang = schema .Language ()
236238 with pytest .raises (dj .DataJointError ):
237239 lang .fetch1 ()
238240
239- def test_fetch1_step3 (self , schema_any ):
241+ def test_fetch1_step3 (self , schema_any , lang ):
240242 """Tests whether fetch1 raises error"""
241- lang = schema .Language ()
242243 with pytest .raises (dj .DataJointError ):
243244 lang .fetch1 ("name" )
244245
@@ -271,9 +272,8 @@ def test_nullable_numbers(self, schema_any):
271272 assert any (np .isnan (d ))
272273 assert any (np .isnan (f ))
273274
274- def test_fetch_format (self , schema_any ):
275+ def test_fetch_format (self , schema_any , subject ):
275276 """test fetch_format='frame'"""
276- subject = schema .Subject ()
277277 with dj .config (fetch_format = "frame" ):
278278 # test if lists are both dicts
279279 list1 = sorted (
@@ -300,9 +300,8 @@ def test_fetch_format(self, schema_any):
300300 for l1 , l2 in zip (list1 , list2 ):
301301 assert l1 == l2 , "Primary key is not returned correctly"
302302
303- def test_key_fetch1 (self , schema_any ):
303+ def test_key_fetch1 (self , schema_any , subject ):
304304 """test KEY fetch1 - issue #976"""
305- subject = schema .Subject ()
306305 with dj .config (fetch_format = "array" ):
307306 k1 = (subject & "subject_id=10" ).fetch1 ("KEY" )
308307 with dj .config (fetch_format = "frame" ):
0 commit comments