1- from nose . tools import assert_true , assert_false , assert_equal , raises
1+ import pytest
22import os
33import numpy as np
44from pathlib import Path
55import tempfile
66import datajoint as dj
7- from . import PREFIX , CONN_INFO
7+ from . import PREFIX
88from datajoint import DataJointError
99
10- schema = dj .Schema (PREFIX + "_update1" , connection = dj .conn (** CONN_INFO ))
1110
12- dj .config ["stores" ]["update_store" ] = dict (protocol = "file" , location = tempfile .mkdtemp ())
13-
14- dj .config ["stores" ]["update_repo" ] = dict (
15- stage = tempfile .mkdtemp (), protocol = "file" , location = tempfile .mkdtemp ()
16- )
17-
18-
19- scratch_folder = tempfile .mkdtemp ()
20-
21- dj .errors ._switch_filepath_types (True )
22-
23-
24- @schema
2511class Thing (dj .Manual ):
2612 definition = """
2713 thing : int
@@ -35,10 +21,38 @@ class Thing(dj.Manual):
3521 """
3622
3723
38- def test_update1 ():
39- """test normal updates"""
24+ @pytest .fixture (scope = "module" )
25+ def mock_stores_update (tmpdir_factory ):
26+ og_stores_config = dj .config .get ("stores" )
27+ if "stores" not in dj .config :
28+ dj .config ["stores" ] = {}
29+ dj .config ["stores" ]["update_store" ] = dict (
30+ protocol = "file" , location = tmpdir_factory .mktemp ("store" )
31+ )
32+ dj .config ["stores" ]["update_repo" ] = dict (
33+ stage = tmpdir_factory .mktemp ("repo_stage" ),
34+ protocol = "file" ,
35+ location = tmpdir_factory .mktemp ("repo_loc" ),
36+ )
37+ yield
38+ if og_stores_config is None :
39+ del dj .config ["stores" ]
40+ else :
41+ dj .config ["stores" ] = og_stores_config
4042
41- dj .errors ._switch_filepath_types (True )
43+
44+ @pytest .fixture
45+ def schema_update1 (connection_test ):
46+ schema = dj .Schema (
47+ PREFIX + "_update1" , context = dict (Thing = Thing ), connection = connection_test
48+ )
49+ schema (Thing )
50+ yield schema
51+ schema .drop ()
52+
53+
54+ def test_update1 (tmpdir , enable_filepath_feature , schema_update1 , mock_stores_update ):
55+ """Test normal updates"""
4256 # CHECK 1 -- initial insert
4357 key = dict (thing = 1 )
4458 Thing .insert1 (dict (key , frac = 0.5 ))
@@ -48,7 +62,7 @@ def test_update1():
4862 # numbers and datetimes
4963 Thing .update1 (dict (key , number = 3 , frac = 30 , timestamp = "2020-01-01 10:00:00" ))
5064 # attachment
51- attach_file = Path (scratch_folder , "attach1.dat" )
65+ attach_file = Path (tmpdir , "attach1.dat" )
5266 buffer1 = os .urandom (100 )
5367 attach_file .write_bytes (buffer1 )
5468 Thing .update1 (dict (key , picture = attach_file ))
@@ -67,7 +81,7 @@ def test_update1():
6781 managed_file .unlink ()
6882 assert not managed_file .is_file ()
6983
70- check2 = Thing .fetch1 (download_path = scratch_folder )
84+ check2 = Thing .fetch1 (download_path = tmpdir )
7185 buffer2 = Path (check2 ["picture" ]).read_bytes () # read attachment
7286 final_file_data = managed_file .read_bytes () # read filepath
7387
@@ -84,37 +98,50 @@ def test_update1():
8498 )
8599 check3 = Thing .fetch1 ()
86100
87- assert check1 ["number" ] == 0 and check1 ["picture" ] is None and check1 ["params" ] is None
101+ assert (
102+ check1 ["number" ] == 0 and check1 ["picture" ] is None and check1 ["params" ] is None
103+ )
88104
89- assert (check2 ["number" ] == 3
105+ assert (
106+ check2 ["number" ] == 3
90107 and check2 ["frac" ] == 30.0
91108 and check2 ["picture" ] is not None
92109 and check2 ["params" ] is None
93- and buffer1 == buffer2 )
110+ and buffer1 == buffer2
111+ )
94112
95- assert (check3 ["number" ] == 0
113+ assert (
114+ check3 ["number" ] == 0
96115 and check3 ["frac" ] == 30.0
97116 and check3 ["picture" ] is None
98117 and check3 ["img_file" ] is None
99- and isinstance (check3 ["params" ], np .ndarray ))
118+ and isinstance (check3 ["params" ], np .ndarray )
119+ )
100120
101121 assert check3 ["timestamp" ] > check2 ["timestamp" ]
102122 assert buffer1 == buffer2
103123 assert original_file_data == final_file_data
104124
105125
106- @raises (DataJointError )
107- def test_update1_nonexistent ():
108- Thing .update1 (dict (thing = 100 , frac = 0.5 )) # updating a non-existent entry
126+ def test_update1_nonexistent (
127+ enable_filepath_feature , schema_update1 , mock_stores_update
128+ ):
129+ with pytest .raises (DataJointError ):
130+ # updating a non-existent entry
131+ Thing .update1 (dict (thing = 100 , frac = 0.5 ))
109132
110133
111- @raises (DataJointError )
112- def test_update1_noprimary ():
113- Thing .update1 (dict (number = None )) # missing primary key
134+ def test_update1_noprimary (enable_filepath_feature , schema_update1 , mock_stores_update ):
135+ with pytest .raises (DataJointError ):
136+ # missing primary key
137+ Thing .update1 (dict (number = None ))
114138
115139
116- @raises (DataJointError )
117- def test_update1_misspelled_attribute ():
140+ def test_update1_misspelled_attribute (
141+ enable_filepath_feature , schema_update1 , mock_stores_update
142+ ):
118143 key = dict (thing = 17 )
119144 Thing .insert1 (dict (key , frac = 1.5 ))
120- Thing .update1 (dict (key , numer = 3 )) # misspelled attribute
145+ with pytest .raises (DataJointError ):
146+ # misspelled attribute
147+ Thing .update1 (dict (key , numer = 3 ))
0 commit comments