@@ -1370,32 +1370,55 @@ def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True):
13701370class ScatterPlot (MPLPlot ):
13711371 _layout_type = 'single'
13721372
1373- def __init__ (self , data , x , y , ** kwargs ):
1373+ def __init__ (self , data , x , y , c = None , ** kwargs ):
13741374 MPLPlot .__init__ (self , data , ** kwargs )
1375- self .kwds .setdefault ('c' , self .plt .rcParams ['patch.facecolor' ])
13761375 if x is None or y is None :
13771376 raise ValueError ( 'scatter requires and x and y column' )
13781377 if com .is_integer (x ) and not self .data .columns .holds_integer ():
13791378 x = self .data .columns [x ]
13801379 if com .is_integer (y ) and not self .data .columns .holds_integer ():
13811380 y = self .data .columns [y ]
1381+ if com .is_integer (c ) and not self .data .columns .holds_integer ():
1382+ c = self .data .columns [c ]
13821383 self .x = x
13831384 self .y = y
1385+ self .c = c
13841386
13851387 @property
13861388 def nseries (self ):
13871389 return 1
13881390
13891391 def _make_plot (self ):
1390- x , y , data = self .x , self .y , self .data
1392+ import matplotlib .pyplot as plt
1393+
1394+ x , y , c , data = self .x , self .y , self .c , self .data
13911395 ax = self .axes [0 ]
13921396
1397+ # plot a colorbar only if a colormap is provided or necessary
1398+ cb = self .kwds .pop ('colorbar' , self .colormap or c in self .data .columns )
1399+
1400+ # pandas uses colormap, matplotlib uses cmap.
1401+ cmap = self .colormap or 'RdBu'
1402+ cmap = plt .cm .get_cmap (cmap )
1403+
1404+ if c is None :
1405+ c_values = self .plt .rcParams ['patch.facecolor' ]
1406+ elif c in self .data .columns :
1407+ c_values = self .data [c ].values
1408+ else :
1409+ c_values = c
1410+
13931411 if self .legend and hasattr (self , 'label' ):
13941412 label = self .label
13951413 else :
13961414 label = None
1397- scatter = ax .scatter (data [x ].values , data [y ].values , label = label ,
1398- ** self .kwds )
1415+ scatter = ax .scatter (data [x ].values , data [y ].values , c = c_values ,
1416+ label = label , cmap = cmap , ** self .kwds )
1417+ if cb :
1418+ img = ax .collections [0 ]
1419+ cb_label = c if c in self .data .columns else ''
1420+ self .fig .colorbar (img , ax = ax , label = cb_label )
1421+
13991422 self ._add_legend_handle (scatter , label )
14001423
14011424 errors_x = self ._get_errorbars (label = x , index = 0 , yerr = False )
@@ -2261,6 +2284,8 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
22612284 colormap : str or matplotlib colormap object, default None
22622285 Colormap to select colors from. If string, load colormap with that name
22632286 from matplotlib.
2287+ colorbar : boolean, optional
2288+ If True, plot colorbar (only relevant for 'scatter' and 'hexbin' plots)
22642289 position : float
22652290 Specify relative alignments for bar plot layout.
22662291 From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center)
@@ -2287,6 +2312,9 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
22872312 `C` specifies the value at each `(x, y)` point and `reduce_C_function`
22882313 is a function of one argument that reduces all the values in a bin to
22892314 a single number (e.g. `mean`, `max`, `sum`, `std`).
2315+
2316+ If `kind`='scatter' and the argument `c` is the name of a dataframe column,
2317+ the values of that column are used to color each point.
22902318 """
22912319
22922320 kind = _get_standard_kind (kind .lower ().strip ())
0 commit comments