diff --git a/pyproject.toml b/pyproject.toml index fed528d..0ea3e89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,3 @@ [build-system] -requires = ["setuptools"] +requires = ["setuptools","Shapely"] build-backend = "setuptools.build_meta" diff --git a/requirements.test.txt b/requirements.test.txt index b3eaa8c..407033e 100644 --- a/requirements.test.txt +++ b/requirements.test.txt @@ -1,2 +1,3 @@ pytest==3.2.5 setuptools +Shapely diff --git a/shapefile.py b/shapefile.py index 0cd8b9e..6ba074c 100644 --- a/shapefile.py +++ b/shapefile.py @@ -8,6 +8,7 @@ __version__ = "2.3.0" +from pickle import NONE from struct import pack, unpack, calcsize, error, Struct import os import sys @@ -18,6 +19,8 @@ import io from datetime import date import zipfile +from shapely.geometry import Point +from shapely.geometry.polygon import Polygon # Create named logger logger = logging.getLogger(__name__) @@ -1296,7 +1299,7 @@ def __shpHeader(self): else: self.mbox.append(None) - def __shape(self, oid=None, bbox=None): + def __shape(self, oid=None, bbox=None, pnt=None): """Returns the header info and geometry for a single shape.""" f = self.__getFileObj(self.shp) record = Shape(oid=oid) @@ -1334,6 +1337,12 @@ def __shape(self, oid=None, bbox=None): if nPoints: flat = unpack("<%sd" % (2 * nPoints), f.read(16*nPoints)) record.points = list(izip(*(iter(flat),) * 2)) + + if pnt is not None and not (Polygon(record.points).contains(Point(pnt))): + # because we stop parsing this shape, skip to beginning of + # next shape before we return + f.seek(next) + return None # Read z extremes and values if shapeType in (13,15,18,31): (zmin, zmax) = unpack("<2d", f.read(16)) @@ -1456,7 +1465,7 @@ def shapes(self, bbox=None): shapes.extend(self.iterShapes(bbox=bbox)) return shapes - def iterShapes(self, bbox=None): + def iterShapes(self, bbox=None, pnt=None): """Returns a generator of shapes in a shapefile. Useful for handling large shapefiles. To only read shapes within a given spatial region, specify the 'bbox' @@ -1475,7 +1484,7 @@ def iterShapes(self, bbox=None): # Iterate exactly the number of shapes from shx header for i in xrange(self.numShapes): # MAYBE: check if more left of file or exit early? - shape = self.__shape(oid=i, bbox=bbox) + shape = self.__shape(oid=i, bbox=bbox, pnt=pnt) if shape: yield shape else: @@ -1487,7 +1496,7 @@ def iterShapes(self, bbox=None): pos = shp.tell() while pos < shpLength: offsets.append(pos) - shape = self.__shape(oid=i, bbox=bbox) + shape = self.__shape(oid=i, bbox=bbox, pnt=pnt) pos = shp.tell() if shape: yield shape @@ -1753,7 +1762,7 @@ def shapeRecords(self, fields=None, bbox=None): """ return ShapeRecords(self.iterShapeRecords(fields=fields, bbox=bbox)) - def iterShapeRecords(self, fields=None, bbox=None): + def iterShapeRecords(self, fields=None, bbox=None, pnt=None): """Returns a generator of combination geometry/attribute records for all records in a shapefile. To only read some of the fields, specify the 'fields' arg as a @@ -1761,7 +1770,7 @@ def iterShapeRecords(self, fields=None, bbox=None): To only read entries within a given spatial region, specify the 'bbox' arg as a list or tuple of xmin,ymin,xmax,ymax. """ - if bbox is None: + if bbox is None and pnt is None: # iterate through all shapes and records for shape, record in izip(self.iterShapes(), self.iterRecords(fields=fields)): yield ShapeRecord(shape=shape, record=record) @@ -1771,7 +1780,7 @@ def iterShapeRecords(self, fields=None, bbox=None): # make sure to seek to correct file location... #fieldTuples,recLookup,recStruct = self.__recordFields(fields) - for shape in self.iterShapes(bbox=bbox): + for shape in self.iterShapes(bbox=bbox, pnt=pnt): if shape: #record = self.__record(oid=i, fieldTuples=fieldTuples, recLookup=recLookup, recStruct=recStruct) record = self.record(i=shape.oid, fields=fields) diff --git a/test_shapefile.py b/test_shapefile.py index d1ac294..48162ed 100644 --- a/test_shapefile.py +++ b/test_shapefile.py @@ -11,6 +11,8 @@ import pytest import json import datetime +from shapely.geometry import Point +from shapely.geometry.polygon import Polygon if sys.version_info.major == 2: # required by pytest for python <36 from pathlib2 import Path @@ -1010,6 +1012,25 @@ def test_bboxfilter_itershaperecords(): assert shaperec.record.oid == man.record.oid assert shaperec.record == man.record +def test_point_in_poly(): + """ + + """ + pnt = [-122.4, 37.8] + with shapefile.Reader("shapefiles/blockgroups") as sf: + # apply pnt filter + shapes = list(sf.iterShapes(pnt=pnt)) + # manually check pnts + manual = shapefile.Shapes() + for shape in sf.iterShapes(): + if Polygon(shape.points).contains(Point(pnt)): + manual.append(shape) + # compare + assert len(shapes) == len(manual) + # check that they line up + for shape,man in zip(shapes,manual): + assert shape.oid == man.oid + assert shape.__geo_interface__ == man.__geo_interface__ def test_shaperecords_shaperecord(): """