bohrium._base_client 源代码

import logging
import platform
import time
from functools import wraps
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Generic,
    Literal,
    Mapping,
    Optional,
    TypeVar,
    Union,
    cast,
)
from urllib.parse import urljoin

import distro
import httpx
from httpx import URL, Limits, Timeout

from ._constants import DEFAULT_CONNECTION_LIMITS, DEFAULT_MAX_RETRIES, DEFAULT_TIMEOUT
from ._utils import lru_cache

logger = logging.getLogger(__name__)

Arch = Union[Literal["x32", "x64", "arm", "arm64", "unknown"]]

Platform = Union[
    Literal[
        "MacOS",
        "Linux",
        "Windows",
        "FreeBSD",
        "OpenBSD",
        "iOS",
        "Android",
        "Unknown",
    ],
]


class _DefaultHttpxClient(httpx.Client):
    def __init__(self, **kwargs: Any) -> None:
        kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
        kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
        kwargs.setdefault("follow_redirects", True)
        super().__init__(**kwargs)


if TYPE_CHECKING:
    DefaultHttpxClient = httpx.Client
    """An alias to `httpx.Client` that provides the same defaults that this SDK
    uses internally.

    This is useful because overriding the `http_client` with your own instance of
    `httpx.Client` will result in httpx's defaults being used, not ours.
    """
else:
    DefaultHttpxClient = _DefaultHttpxClient

_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])


class SyncHttpxClientWrapper(DefaultHttpxClient):
    def __del__(self) -> None:
        try:
            self.close()
        except Exception:
            pass


def retry(max_retries_func, delay=1, exceptions=(Exception,)):
    """
    Decorator for retrying a function upon exception.

    max_retries: The maximum number of retries before giving up.
    delay: Initial delay between retries in seconds.
    backoff: Multiplier applied to delay between retries (e.g., for exponential backoff).
    exceptions: Tuple of exceptions to catch. Defaults to base Exception class.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            self = args[0]
            max_retries = max_retries_func(self)
            retries = 0
            while retries < max_retries:
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    if retries >= max_retries:
                        logger.error(
                            f"Exhausted all retries. Raising the last caught exception: {e}"
                        )
                        raise e
                    else:
                        logger.warning(
                            f"Retry {retries} due to error: {e}. Retrying in {delay} seconds..."
                        )
                        time.sleep(delay)
                    retries += 1

        return wrapper

    return decorator


[文档] class BaseClient(Generic[_HttpxClientT]): _client: _HttpxClientT _base_url: URL max_retries: int timeout: Union[float, Timeout, None] _limits: httpx.Limits _version: Union[str, None] def __init__( self, *, _version: Optional[str] = None, base_url: Union[str, URL], limits: httpx.Limits, max_retries: int = DEFAULT_MAX_RETRIES, timeout: Optional[Union[float, httpx.Timeout]] = DEFAULT_TIMEOUT, custom_headers: Optional[Mapping[str, str]] = None, ): self._base_url = self._enforce_trailing_slash(URL(base_url)) self._version = _version self.max_retries = max_retries self.timeout = timeout self._limits = limits self._custom_headers = custom_headers or {} self.backoff_base = 2 def _enforce_trailing_slash(self, url: URL) -> URL: if url.raw_path.endswith(b"/"): return url return url.copy_with(raw_path=url.raw_path + b"/") def _build_headers(self, custom_headers) -> httpx.Headers: headers_dict = _merge_mappings( self.default_headers, self._custom_headers, custom_headers ) # 过滤掉 value 为 None 的 header headers_dict = {k: v for k, v in headers_dict.items() if v is not None} headers = httpx.Headers(headers_dict) return headers or dict() def _build_params(self, custom_params) -> dict[str, str]: params = _merge_mappings(self.default_params, custom_params) return params @property def custom_auth(self) -> Union[httpx.Auth, None]: return None @property def auth_headers(self) -> dict[str, str]: return {} @property def default_headers(self) -> dict[str, str]: return { "Accept": "application/json", "Content-Type": "application/json", } @property def default_params(self) -> dict[str, str]: return {"accessKey": self.access_key}
[文档] def platform_headers(self) -> Dict[str, str]: return platform_headers(self._version)
@retry( max_retries_func=lambda self: self.max_retries, exceptions=(httpx.RequestError,), ) def _request( self, method: str, path: str, json=None, headers=None, data=None, **kwargs ) -> httpx.Response: url = urljoin(str(self._base_url), path) logger.info(f"Requesting {method} {url}") merged_headers = self._build_headers(headers) merged_params = self._build_params(kwargs.get("params")) # 处理文件上传 request_kwargs = { "method": method.upper(), "url": url, "params": merged_params, } # 处理超时参数 if "timeout" in kwargs: request_kwargs["timeout"] = kwargs["timeout"] if json is not None: request_kwargs["json"] = json request_kwargs["headers"] = merged_headers elif "files" in kwargs: # 当有files参数时,不使用json参数,而是使用files和data # 不设置headers,让httpx自动处理multipart/form-data request_kwargs["files"] = kwargs["files"] if "data" in kwargs: request_kwargs["data"] = kwargs["data"] elif "data" in kwargs: request_kwargs["data"] = kwargs["data"] request_kwargs["headers"] = merged_headers else: request_kwargs["headers"] = merged_headers try: return self._client.request( method.upper(), url, json=json, data=data, headers=merged_headers, params=merged_params, ) except httpx.TransportError as e: logger.error(f"Transport error: {e}.") raise e except httpx.RequestError as e: logger.error(f"Request failed with error: {e}.")
[文档] def get(self, path: str, headers=None, **kwargs) -> httpx.Response: """Make a GET request.""" return self._request("GET", path, headers=headers, **kwargs)
[文档] def post(self, path: str, json=None, headers=None, **kwargs) -> httpx.Response: """Make a POST request.""" return self._request("POST", path, json=json, headers=headers, **kwargs)
[文档] def patch(self, path: str, json=None, headers=None, **kwargs) -> httpx.Response: """Make a PATCH request.""" return self._request("PATCH", path, json=json, headers=headers, **kwargs)
[文档] def put(self, path: str, json=None, headers=None, **kwargs) -> httpx.Response: """Make a PUT request.""" return self._request("PUT", path, json=json, headers=headers, **kwargs)
[文档] def delete(self, path: str, headers=None, **kwargs) -> httpx.Response: """Make a DELETE request.""" return self._request("DELETE", path, headers=headers, **kwargs)
[文档] def close(self): """Close the underlying HTTPX client.""" self.client.close()
[文档] class SyncAPIClient(BaseClient[httpx.Client]): _client: httpx.Client def __init__( self, base_url: Union[str, httpx.URL], max_retries: int = DEFAULT_MAX_RETRIES, timeout: Union[float, Timeout, None] = DEFAULT_TIMEOUT, limits: Limits = DEFAULT_CONNECTION_LIMITS, _version: Optional[str] = None, http_client: Optional[httpx.Client] = None, custom_headers: Optional[Mapping[str, str]] = None, ) -> None: if http_client is not None and not isinstance( http_client, httpx.Client ): # pyright: ignore[reportUnnecessaryIsInstance] raise TypeError( f"Invalid `http_client` argument; Expected an instance of `httpx.Client` but got {type(http_client)}" ) print(timeout) super().__init__( _version=_version, limits=limits, # cast to a valid type because mypy doesn't understand our type narrowing timeout=cast(Timeout, timeout), base_url=base_url, max_retries=max_retries, custom_headers=custom_headers, ) self._client = http_client or SyncHttpxClientWrapper( base_url=base_url, # cast to a valid type because mypy doesn't understand our type narrowing timeout=cast(Timeout, timeout), limits=limits, follow_redirects=True, )
[文档] def is_closed(self) -> bool: return self._client.is_closed
[文档] def close(self) -> None: """Close the underlying HTTPX client. The client will *not* be usable after this. """ # If an error is thrown while constructing a client, self._client # may not be present if hasattr(self, "_client"): self._client.close()
def _prepare_request( self, request: httpx.Request, # noqa: ARG002 ) -> None: """This method is used as a callback for mutating the `Request` object after it has been constructed. This is useful for cases where you want to add certain headers based off of the request properties, e.g. `url`, `method` etc. """ return None
def _merge_mappings(*mappings): merged = {} for mapping in mappings: if mapping: merged.update(mapping) return merged
[文档] class AsyncAPIClient(BaseClient[httpx.AsyncClient]): pass
@lru_cache(maxsize=None) def get_platform() -> Platform: try: system = platform.system().lower() platform_name = platform.platform().lower() except Exception: return "Unknown" if "iphone" in platform_name or "ipad" in platform_name: # Tested using Python3IDE on an iPhone 11 and Pythonista on an iPad 7 # system is Darwin and platform_name is a string like: # - Darwin-21.6.0-iPhone12,1-64bit # - Darwin-21.6.0-iPad7,11-64bit return "iOS" if system == "darwin": return "MacOS" if system == "windows": return "Windows" if "android" in platform_name: # Tested using Pydroid 3 # system is Linux and platform_name is a string like 'Linux-5.10.81-android12-9-00001-geba40aecb3b7-ab8534902-aarch64-with-libc' return "Android" if system == "linux": # https://distro.readthedocs.io/en/latest/#distro.id distro_id = distro.id() if distro_id == "freebsd": return "FreeBSD" if distro_id == "openbsd": return "OpenBSD" return "Linux" return "Unknown" def platform_headers(version: str) -> Dict[str, str]: return { "X-Stainless-Lang": "python", "X-Stainless-Package-Version": version, "X-Stainless-OS": str(get_platform()), "X-Stainless-Arch": str(get_architecture()), "X-Stainless-Runtime": get_python_runtime(), "X-Stainless-Runtime-Version": get_python_version(), } def get_architecture() -> Arch: try: python_bitness, _ = platform.architecture() machine = platform.machine().lower() except Exception: return "unknown" if machine in ("arm64", "aarch64"): return "arm64" # TODO: untested if machine == "arm": return "arm" if machine == "x86_64": return "x64" # TODO: untested if python_bitness == "32bit": return "x32" return "unknown" def get_python_runtime() -> str: try: return platform.python_implementation() except Exception: return "unknown" def get_python_version() -> str: try: return platform.python_version() except Exception: return "unknown"