2424from typing_extensions import Self , TypeAlias
2525
2626import aiohttp
27- import packaging .specifiers
2827import packaging .version
2928import tomlkit
29+ from packaging .specifiers import Specifier
3030from termcolor import colored
3131
3232from ts_utils .metadata import StubMetadata , metadata_path , read_metadata , stubs_path
@@ -121,14 +121,21 @@ async def fetch_pypi_info(distribution: str, session: aiohttp.ClientSession) ->
121121@dataclass
122122class Update :
123123 distribution : str
124- old_version_spec : str
125- new_version_spec : str
124+ old_version_spec : Specifier
125+ new_version_spec : Specifier
126126 links : dict [str , str ]
127127 diff_analysis : DiffAnalysis | None
128128
129129 def __str__ (self ) -> str :
130130 return f"Updating { self .distribution } from { self .old_version_spec !r} to { self .new_version_spec !r} "
131131
132+ @property
133+ def new_version (self ) -> str :
134+ if self .new_version_spec .operator == "==" :
135+ return str (self .new_version_spec )[2 :]
136+ else :
137+ return str (self .new_version_spec )
138+
132139
133140@dataclass
134141class Obsolete :
@@ -239,12 +246,7 @@ async def find_first_release_with_py_typed(pypi_info: PypiInfo, *, session: aioh
239246 return first_release_with_py_typed
240247
241248
242- def _check_spec (updated_spec : str , version : packaging .version .Version ) -> str :
243- assert version in packaging .specifiers .SpecifierSet (f"=={ updated_spec } " ), f"{ version } not in { updated_spec } "
244- return updated_spec
245-
246-
247- def get_updated_version_spec (spec : str , version : packaging .version .Version ) -> str :
249+ def get_updated_version_spec (spec : Specifier , version : packaging .version .Version ) -> Specifier :
248250 """
249251 Given the old specifier and an updated version, returns an updated specifier that has the
250252 specificity of the old specifier, but matches the updated version.
@@ -256,15 +258,22 @@ def get_updated_version_spec(spec: str, version: packaging.version.Version) -> s
256258 spec="1.*", version="2.3.4" -> "2.*"
257259 spec="1.1.*", version="1.2.3" -> "1.2.*"
258260 spec="1.1.1.*", version="1.2.3" -> "1.2.3.*"
261+ spec="~=1.0.1", version="1.0.3" -> "~=1.0.3"
262+ spec="~=1.0.1", version="1.1.0" -> "~=1.1.0"
259263 """
260- if not spec .endswith (".*" ):
261- return _check_spec (str (version ), version )
262-
263- specificity = spec .count ("." ) if spec .removesuffix (".*" ) else 0
264- rounded_version = version .base_version .split ("." )[:specificity ]
265- rounded_version .extend (["0" ] * (specificity - len (rounded_version )))
266-
267- return _check_spec ("." .join (rounded_version ) + ".*" , version )
264+ if spec .operator == "==" and spec .version .endswith (".*" ):
265+ specificity = spec .version .count ("." ) if spec .version .removesuffix (".*" ) else 0
266+ rounded_version = version .base_version .split ("." )[:specificity ]
267+ rounded_version .extend (["0" ] * (specificity - len (rounded_version )))
268+ updated_spec = Specifier ("==" + "." .join (rounded_version ) + ".*" )
269+ elif spec .operator == "==" :
270+ updated_spec = Specifier (f"=={ version } " )
271+ elif spec .operator == "~=" :
272+ updated_spec = Specifier (f"~={ version } " )
273+ else :
274+ raise ValueError (f"Unsupported version operator: { spec .operator } " )
275+ assert version in updated_spec , f"{ version } not in { updated_spec } "
276+ return updated_spec
268277
269278
270279@functools .cache
@@ -333,15 +342,13 @@ async def get_diff_info(
333342 with contextlib .suppress (packaging .version .InvalidVersion ):
334343 versions_to_tags [packaging .version .Version (tag_name )] = tag_name
335344
336- curr_specifier = packaging .specifiers .SpecifierSet (f"=={ stub_info .version } " )
337-
338345 try :
339346 new_tag = versions_to_tags [pypi_version ]
340347 except KeyError :
341348 return None
342349
343350 try :
344- old_version = max (version for version in versions_to_tags if version in curr_specifier and version < pypi_version )
351+ old_version = max (version for version in versions_to_tags if version in stub_info . version_spec and version < pypi_version )
345352 except ValueError :
346353 return None
347354 else :
@@ -472,9 +479,8 @@ async def determine_action(distribution: str, session: aiohttp.ClientSession) ->
472479 pypi_info = await fetch_pypi_info (stub_info .distribution , session )
473480 latest_release = pypi_info .get_latest_release ()
474481 latest_version = latest_release .version
475- spec = packaging .specifiers .SpecifierSet (f"=={ stub_info .version } " )
476482 obsolete_since = await find_first_release_with_py_typed (pypi_info , session = session )
477- if obsolete_since is None and latest_version in spec :
483+ if obsolete_since is None and latest_version in stub_info . version_spec :
478484 return NoUpdate (stub_info .distribution , "up to date" )
479485
480486 relevant_version = obsolete_since .version if obsolete_since else latest_version
@@ -514,8 +520,8 @@ async def determine_action(distribution: str, session: aiohttp.ClientSession) ->
514520
515521 return Update (
516522 distribution = stub_info .distribution ,
517- old_version_spec = stub_info .version ,
518- new_version_spec = get_updated_version_spec (stub_info .version , latest_version ),
523+ old_version_spec = stub_info .version_spec ,
524+ new_version_spec = get_updated_version_spec (stub_info .version_spec , latest_version ),
519525 links = links ,
520526 diff_analysis = diff_analysis ,
521527 )
@@ -678,13 +684,13 @@ def get_update_pr_body(update: Update, metadata: dict[str, Any]) -> str:
678684async def suggest_typeshed_update (update : Update , session : aiohttp .ClientSession , action_level : ActionLevel ) -> None :
679685 if action_level <= ActionLevel .nothing :
680686 return
681- title = f"[stubsabot] Bump { update .distribution } to { update .new_version_spec } "
687+ title = f"[stubsabot] Bump { update .distribution } to { update .new_version } "
682688 async with _repo_lock :
683689 branch_name = f"{ BRANCH_PREFIX } /{ normalize (update .distribution )} "
684690 subprocess .check_call (["git" , "checkout" , "-B" , branch_name , "origin/main" ])
685691 with metadata_path (update .distribution ).open ("rb" ) as f :
686692 meta = tomlkit .load (f )
687- meta ["version" ] = update .new_version_spec
693+ meta ["version" ] = update .new_version
688694 with metadata_path (update .distribution ).open ("w" , encoding = "UTF-8" ) as f :
689695 # tomlkit.dump has partially unknown IO type
690696 tomlkit .dump (meta , f ) # pyright: ignore[reportUnknownMemberType]
0 commit comments