99You can find more information on http://presbrey.mit.edu/PyDTA and
1010http://statsmodels.sourceforge.net/devel/
1111"""
12- # TODO: Fix this module so it can use cross-compatible zip, map, and range
1312import numpy as np
1413
1514import sys
2019from pandas .core .categorical import Categorical
2120import datetime
2221from pandas import compat
23- from pandas .compat import long , lrange , lmap , lzip , text_type , string_types
22+ from pandas .compat import lrange , lmap , lzip , text_type , string_types , range , \
23+ zip
2424from pandas import isnull
2525from pandas .io .common import get_filepath_or_buffer
2626from pandas .lib import max_len_string_array , is_string_array
2727from pandas .tslib import NaT
2828
2929def read_stata (filepath_or_buffer , convert_dates = True ,
30- convert_categoricals = True , encoding = None , index = None ):
30+ convert_categoricals = True , encoding = None , index = None ,
31+ convert_missing = False ):
3132 """
3233 Read Stata file into DataFrame
3334
@@ -44,10 +45,19 @@ def read_stata(filepath_or_buffer, convert_dates=True,
4445 support unicode. None defaults to cp1252.
4546 index : identifier of index column
4647 identifier of column that should be used as index of the DataFrame
48+ convert_missing : boolean, defaults to False
49+ Flag indicating whether to convert missing values to their Stata
50+ representations. If False, missing values are replaced with nans.
51+ If True, columns containing missing values are returned with
52+ object data types and missing values are represented by
53+ StataMissingValue objects.
4754 """
4855 reader = StataReader (filepath_or_buffer , encoding )
4956
50- return reader .data (convert_dates , convert_categoricals , index )
57+ return reader .data (convert_dates ,
58+ convert_categoricals ,
59+ index ,
60+ convert_missing )
5161
5262_date_formats = ["%tc" , "%tC" , "%td" , "%d" , "%tw" , "%tm" , "%tq" , "%th" , "%ty" ]
5363
@@ -291,35 +301,76 @@ class StataMissingValue(StringMixin):
291301
292302 Parameters
293303 -----------
294- offset
295- value
304+ value : int8, int16, int32, float32 or float64
305+ The Stata missing value code
296306
297307 Attributes
298308 ----------
299- string
300- value
309+ string : string
310+ String representation of the Stata missing value
311+ value : int8, int16, int32, float32 or float64
312+ The original encoded missing value
301313
302314 Notes
303315 -----
304316 More information: <http://www.stata.com/help.cgi?missing>
317+
318+ Integer missing values make the code '.', '.a', ..., '.z' to the ranges
319+ 101 ... 127 (for int8), 32741 ... 32767 (for int16) and 2147483621 ...
320+ 2147483647 (for int32). Missing values for floating point data types are
321+ more complex but the pattern is simple to discern from the following table.
322+
323+ np.float32 missing values (float in Stata)
324+ 0000007f .
325+ 0008007f .a
326+ 0010007f .b
327+ ...
328+ 00c0007f .x
329+ 00c8007f .y
330+ 00d0007f .z
331+
332+ np.float64 missing values (double in Stata)
333+ 000000000000e07f .
334+ 000000000001e07f .a
335+ 000000000002e07f .b
336+ ...
337+ 000000000018e07f .x
338+ 000000000019e07f .y
339+ 00000000001ae07f .z
305340 """
306- # TODO: Needs test
307- def __init__ (self , offset , value ):
341+
342+ # Construct a dictionary of missing values
343+ MISSING_VALUES = {}
344+ bases = (101 , 32741 , 2147483621 )
345+ for b in bases :
346+ MISSING_VALUES [b ] = '.'
347+ for i in range (1 , 27 ):
348+ MISSING_VALUES [i + b ] = '.' + chr (96 + i )
349+
350+ base = b'\x00 \x00 \x00 \x7f '
351+ increment = struct .unpack ('<i' , b'\x00 \x08 \x00 \x00 ' )[0 ]
352+ for i in range (27 ):
353+ value = struct .unpack ('<f' , base )[0 ]
354+ MISSING_VALUES [value ] = '.'
355+ if i > 0 :
356+ MISSING_VALUES [value ] += chr (96 + i )
357+ int_value = struct .unpack ('<i' , struct .pack ('<f' , value ))[0 ] + increment
358+ base = struct .pack ('<i' , int_value )
359+
360+ base = b'\x00 \x00 \x00 \x00 \x00 \x00 \xe0 \x7f '
361+ increment = struct .unpack ('q' , b'\x00 \x00 \x00 \x00 \x00 \x01 \x00 \x00 ' )[0 ]
362+ for i in range (27 ):
363+ value = struct .unpack ('<d' , base )[0 ]
364+ MISSING_VALUES [value ] = '.'
365+ if i > 0 :
366+ MISSING_VALUES [value ] += chr (96 + i )
367+ int_value = struct .unpack ('q' , struct .pack ('<d' , value ))[0 ] + increment
368+ base = struct .pack ('q' , int_value )
369+
370+ def __init__ (self , value ):
308371 self ._value = value
309- value_type = type (value )
310- if value_type in int :
311- loc = value - offset
312- elif value_type in (float , np .float32 , np .float64 ):
313- if value <= np .finfo (np .float32 ).max : # float32
314- conv_str , byte_loc , scale = '<f' , 1 , 8
315- else :
316- conv_str , byte_loc , scale = '<d' , 5 , 1
317- value_bytes = struct .pack (conv_str , value )
318- loc = (struct .unpack ('<b' , value_bytes [byte_loc ])[0 ] / scale ) + 0
319- else :
320- # Should never be hit
321- loc = 0
322- self ._str = loc is 0 and '.' or ('.' + chr (loc + 96 ))
372+ self ._str = self .MISSING_VALUES [value ]
373+
323374 string = property (lambda self : self ._str ,
324375 doc = "The Stata representation of the missing value: "
325376 "'.', '.a'..'.z'" )
@@ -333,6 +384,10 @@ def __repr__(self):
333384 # not perfect :-/
334385 return "%s(%s)" % (self .__class__ , self )
335386
387+ def __eq__ (self , other ):
388+ return (isinstance (other , self .__class__ )
389+ and self .string == other .string and self .value == other .value )
390+
336391
337392class StataParser (object ):
338393 _default_encoding = 'cp1252'
@@ -711,15 +766,7 @@ def _col_size(self, k=None):
711766 return self .col_sizes [k ]
712767
713768 def _unpack (self , fmt , byt ):
714- d = struct .unpack (self .byteorder + fmt , byt )[0 ]
715- if fmt [- 1 ] in self .VALID_RANGE :
716- nmin , nmax = self .VALID_RANGE [fmt [- 1 ]]
717- if d < nmin or d > nmax :
718- if self ._missing_values :
719- return StataMissingValue (nmax , d )
720- else :
721- return None
722- return d
769+ return struct .unpack (self .byteorder + fmt , byt )[0 ]
723770
724771 def _null_terminate (self , s ):
725772 if compat .PY3 or self ._encoding is not None : # have bytes not strings,
@@ -752,16 +799,15 @@ def _next(self):
752799 )
753800 return data
754801 else :
755- return list (
756- map (
802+ return lmap (
757803 lambda i : self ._unpack (typlist [i ],
758804 self .path_or_buf .read (
759805 self ._col_size (i )
760806 )),
761807 range (self .nvar )
762- )
763808 )
764809
810+
765811 def _dataset (self ):
766812 """
767813 Returns a Python generator object for iterating over the dataset.
@@ -853,7 +899,8 @@ def _read_strls(self):
853899 self .GSO [v_o ] = self .path_or_buf .read (length - 1 )
854900 self .path_or_buf .read (1 ) # zero-termination
855901
856- def data (self , convert_dates = True , convert_categoricals = True , index = None ):
902+ def data (self , convert_dates = True , convert_categoricals = True , index = None ,
903+ convert_missing = False ):
857904 """
858905 Reads observations from Stata file, converting them into a dataframe
859906
@@ -866,11 +913,18 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None):
866913 variables
867914 index : identifier of index column
868915 identifier of column that should be used as index of the DataFrame
916+ convert_missing : boolean, defaults to False
917+ Flag indicating whether to convert missing values to their Stata
918+ representation. If False, missing values are replaced with
919+ nans. If True, columns containing missing values are returned with
920+ object data types and missing values are represented by
921+ StataMissingValue objects.
869922
870923 Returns
871924 -------
872925 y : DataFrame instance
873926 """
927+ self ._missing_values = convert_missing
874928 if self ._data_read :
875929 raise Exception ("Data has already been read." )
876930 self ._data_read = True
@@ -894,18 +948,62 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None):
894948 if convert_categoricals :
895949 self ._read_value_labels ()
896950
951+ # TODO: Refactor to use a dictionary constructor and the correct dtype from the start?
897952 if len (data )== 0 :
898953 data = DataFrame (columns = self .varlist , index = index )
899954 else :
900955 data = DataFrame (data , columns = self .varlist , index = index )
901956
902957 cols_ = np .where (self .dtyplist )[0 ]
958+
959+ # Convert columns (if needed) to match input type
960+ index = data .index
961+ requires_type_conversion = False
962+ data_formatted = []
903963 for i in cols_ :
904964 if self .dtyplist [i ] is not None :
905965 col = data .columns [i ]
906- if data [col ].dtype is not np .dtype (object ):
907- data [col ] = Series (data [col ], data [col ].index ,
908- self .dtyplist [i ])
966+ dtype = data [col ].dtype
967+ if (dtype != np .dtype (object )) and (dtype != self .dtyplist [i ]):
968+ requires_type_conversion = True
969+ data_formatted .append ((col , Series (data [col ], index , self .dtyplist [i ])))
970+ else :
971+ data_formatted .append ((col , data [col ]))
972+ if requires_type_conversion :
973+ data = DataFrame .from_items (data_formatted )
974+ del data_formatted
975+
976+ # Check for missing values, and replace if found
977+ for i , colname in enumerate (data ):
978+ fmt = self .typlist [i ]
979+ if fmt not in self .VALID_RANGE :
980+ continue
981+
982+ nmin , nmax = self .VALID_RANGE [fmt ]
983+ series = data [colname ]
984+ missing = np .logical_or (series < nmin , series > nmax )
985+
986+ if not missing .any ():
987+ continue
988+
989+ if self ._missing_values : # Replacement follows Stata notation
990+ missing_loc = np .argwhere (missing )
991+ umissing , umissing_loc = np .unique (series [missing ],
992+ return_inverse = True )
993+ replacement = Series (series , dtype = np .object )
994+ for i , um in enumerate (umissing ):
995+ missing_value = StataMissingValue (um )
996+
997+ loc = missing_loc [umissing_loc == i ]
998+ replacement .iloc [loc ] = missing_value
999+ else : # All replacements are identical
1000+ dtype = series .dtype
1001+ if dtype not in (np .float32 , np .float64 ):
1002+ dtype = np .float64
1003+ replacement = Series (series , dtype = dtype )
1004+ replacement [missing ] = np .nan
1005+
1006+ data [colname ] = replacement
9091007
9101008 if convert_dates :
9111009 cols = np .where (lmap (lambda x : x in _date_formats ,
0 commit comments