From 3556bcb1466e3f8ec95ffe83a54e8e704c096ec7 Mon Sep 17 00:00:00 2001 From: KopekC Date: Thu, 6 Mar 2025 16:49:58 -0500 Subject: [PATCH 1/2] chore: Integrate xAI as provider --- pyproject.toml | 3 ++- src/codegen/extensions/langchain/llm.py | 12 +++++++++++- uv.lock | 17 +++++++++++++++++ 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ea7fa4ae1..970ecf57e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ dependencies = [ "hatch-vcs>=0.4.0", "hatchling>=1.25.0", "pyinstrument>=5.0.0", - "pip>=24.3.1", # This is needed for some NPM/YARN/PNPM post-install scripts to work! + "pip>=24.3.1", # This is needed for some NPM/YARN/PNPM post-install scripts to work! "rich-click>=1.8.5", "python-dotenv>=1.0.1", "giturlparse", @@ -78,6 +78,7 @@ dependencies = [ "datasets", "colorlog>=6.9.0", "langsmith", + "langchain-xai>=0.2.1", ] license = { text = "Apache-2.0" } diff --git a/src/codegen/extensions/langchain/llm.py b/src/codegen/extensions/langchain/llm.py index 1aafce31c..d2cec02a3 100644 --- a/src/codegen/extensions/langchain/llm.py +++ b/src/codegen/extensions/langchain/llm.py @@ -13,6 +13,7 @@ from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool from langchain_openai import ChatOpenAI +from langchain_xai import ChatXAI from pydantic import Field @@ -76,6 +77,9 @@ def _get_model_kwargs(self) -> dict[str, Any]: if self.model_provider == "anthropic": return {**base_kwargs, "model": self.model_name} + elif self.model_provider == "xai": + xai_api_base = os.getenv("XAI_API_BASE", "https://api.x.ai/v1/") + return {**base_kwargs, "model": self.model_name, "xai_api_base": xai_api_base} else: # openai return {**base_kwargs, "model": self.model_name} @@ -93,7 +97,13 @@ def _get_model(self) -> BaseChatModel: raise ValueError(msg) return ChatOpenAI(**self._get_model_kwargs()) - msg = f"Unknown model provider: {self.model_provider}. Must be one of: anthropic, openai" + elif self.model_provider == "xai": + if not os.getenv("XAI_API_KEY"): + msg = "XAI_API_KEY not found in environment. Please set it in your .env file or environment variables." + raise ValueError(msg) + return ChatXAI(**self._get_model_kwargs()) + + msg = f"Unknown model provider: {self.model_provider}. Must be one of: anthropic, openai, xai" raise ValueError(msg) def _generate( diff --git a/uv.lock b/uv.lock index 86bc5b166..9796f29bb 100644 --- a/uv.lock +++ b/uv.lock @@ -561,6 +561,7 @@ dependencies = [ { name = "langchain-anthropic" }, { name = "langchain-core" }, { name = "langchain-openai" }, + { name = "langchain-xai" }, { name = "langgraph" }, { name = "langgraph-prebuilt" }, { name = "langsmith" }, @@ -689,6 +690,7 @@ requires-dist = [ { name = "langchain-anthropic", specifier = ">=0.3.7" }, { name = "langchain-core" }, { name = "langchain-openai" }, + { name = "langchain-xai" }, { name = "langgraph" }, { name = "langgraph-prebuilt" }, { name = "langsmith" }, @@ -2109,6 +2111,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4c/f8/6b82af988e65af9697f6a2f25373fb173fd32d48b62772a8773c5184c870/langchain_text_splitters-0.3.6-py3-none-any.whl", hash = "sha256:e5d7b850f6c14259ea930be4a964a65fa95d9df7e1dbdd8bad8416db72292f4e", size = 31197 }, ] +[[package]] +name = "langchain-xai" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "langchain-core" }, + { name = "langchain-openai" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/94/a633bf1b4bbf66e4516f4188adc1174480c465ae12fb98f06c3e23c98519/langchain_xai-0.2.1.tar.gz", hash = "sha256:143a6f52be7617b5e5c68ab10c9b7df90914f54a6b3098566ce22b5d8fd89da5", size = 7788 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/88/d8050e610fadabf97c1745d24f0987b3e53b72fca63c8038ab1e0c103da9/langchain_xai-0.2.1-py3-none-any.whl", hash = "sha256:87228125cb15131663979d627210fca47dcd6b9a28462e8b5fee47f73bbed9f4", size = 6263 }, +] + [[package]] name = "langgraph" version = "0.3.2" From 4da1f60639f9bf54bc06db3cf503c022d4c100aa Mon Sep 17 00:00:00 2001 From: kopekC <28070492+kopekC@users.noreply.github.com> Date: Thu, 6 Mar 2025 21:51:46 +0000 Subject: [PATCH 2/2] Automated pre-commit update --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 970ecf57e..777d80ec1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ dependencies = [ "hatch-vcs>=0.4.0", "hatchling>=1.25.0", "pyinstrument>=5.0.0", - "pip>=24.3.1", # This is needed for some NPM/YARN/PNPM post-install scripts to work! + "pip>=24.3.1", # This is needed for some NPM/YARN/PNPM post-install scripts to work! "rich-click>=1.8.5", "python-dotenv>=1.0.1", "giturlparse",