From 4c9e86b208c039a6d780b2161612ebf0d53125fe Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Thu, 13 Nov 2025 23:44:40 +0530 Subject: [PATCH 01/27] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- .../src/fms_acceleration/constants.py | 2 +- plugins/mamba-cp/.DS_Store | Bin 0 -> 6148 bytes plugins/mamba-cp/.isort.cfg | 10 + plugins/mamba-cp/.pylintrc | 649 ++++++++++++++++++ plugins/mamba-cp/README.md | 1 + plugins/mamba-cp/configs/mcp.yaml | 8 + plugins/mamba-cp/pyproject.toml | 29 + .../src/fms_acceleration_mcp/.DS_Store | Bin 0 -> 6148 bytes .../src/fms_acceleration_mcp/__init__.py | 20 + .../src/fms_acceleration_mcp/callback.py | 38 + .../framework_plugin_mcp.py | 90 +++ .../src/fms_acceleration_mcp/patch.py | 233 +++++++ .../fms_acceleration_mcp/utils/__init__.py | 17 + .../src/fms_acceleration_mcp/utils/utils.py | 89 +++ plugins/mamba-cp/tox.ini | 50 ++ 15 files changed, 1235 insertions(+), 1 deletion(-) create mode 100644 plugins/mamba-cp/.DS_Store create mode 100644 plugins/mamba-cp/.isort.cfg create mode 100644 plugins/mamba-cp/.pylintrc create mode 100644 plugins/mamba-cp/README.md create mode 100644 plugins/mamba-cp/configs/mcp.yaml create mode 100644 plugins/mamba-cp/pyproject.toml create mode 100644 plugins/mamba-cp/src/fms_acceleration_mcp/.DS_Store create mode 100644 plugins/mamba-cp/src/fms_acceleration_mcp/__init__.py create mode 100644 plugins/mamba-cp/src/fms_acceleration_mcp/callback.py create mode 100644 plugins/mamba-cp/src/fms_acceleration_mcp/framework_plugin_mcp.py create mode 100644 plugins/mamba-cp/src/fms_acceleration_mcp/patch.py create mode 100644 plugins/mamba-cp/src/fms_acceleration_mcp/utils/__init__.py create mode 100644 plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py create mode 100644 plugins/mamba-cp/tox.ini diff --git a/plugins/framework/src/fms_acceleration/constants.py b/plugins/framework/src/fms_acceleration/constants.py index 252842e0..0e8f522d 100644 --- a/plugins/framework/src/fms_acceleration/constants.py +++ b/plugins/framework/src/fms_acceleration/constants.py @@ -21,4 +21,4 @@ # and activated. # - hence the plugins that have model loaders should be on top of this list -PLUGINS = ["peft", "foak", "aadp", "moe", "odm"] +PLUGINS = ["peft", "foak", "aadp", "moe", "odm", "mcp"] diff --git a/plugins/mamba-cp/.DS_Store b/plugins/mamba-cp/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..19e14ad0de00558677b2018b401848ef82a27d94 GIT binary patch literal 6148 zcmeHKO-lnY5Pi`e3WD_LF@K;eUOg?P-o%S|w{B}&+L9saGV24gUgLZI z1$NlueP2J{HC26}-whtQvju%CWO&7f{S~=eax-!_w2xJRf-|Lh&NS|hJ{hBxXt3lu zTa6&rQb!B&=495aW5!wIP3=pv{?6|{<@f78PW^}Q{9FN7z!h)>j-~+bY$?MFLvLLH zSHKncQb5j!$SRm8Yz_73V53g};(%r??8{q1I5A$g#fvy7QHXOZvC-6xoacq6RQ;QYo)`&e&-{A135>obkgRd^y!x;Yz-xf*p*K74*?a#8&}{6 G3VZ-AH9E)u literal 0 HcmV?d00001 diff --git a/plugins/mamba-cp/.isort.cfg b/plugins/mamba-cp/.isort.cfg new file mode 100644 index 00000000..7d3762ec --- /dev/null +++ b/plugins/mamba-cp/.isort.cfg @@ -0,0 +1,10 @@ +[settings] +profile=black +from_first=true +import_heading_future=Future +import_heading_stdlib=Standard +import_heading_thirdparty=Third Party +import_heading_firstparty=First Party +import_heading_localfolder=Local +known_firstparty= +known_localfolder=tuning \ No newline at end of file diff --git a/plugins/mamba-cp/.pylintrc b/plugins/mamba-cp/.pylintrc new file mode 100644 index 00000000..4dc16dbc --- /dev/null +++ b/plugins/mamba-cp/.pylintrc @@ -0,0 +1,649 @@ +[MAIN] + +# Analyse import fallback blocks. 
This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS,protobufs + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths=.*megablocks,.*khd + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Minimum Python version to use for version dependent checks. 
Will default to +# the version used to run pylint. +py-version=3.11 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. 
+#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=8 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. 
+ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1100 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". 
+disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + # Added messages + use-symbolic-message-instead, + invalid-name, + missing-class-docstring, + missing-module-docstring, + missing-function-docstring, + consider-using-f-string, + inconsistent-return-statements, + no-member, + too-many-arguments, + too-many-locals, + too-many-branches, + too-many-statements, + cyclic-import, + too-few-public-methods, + protected-access, + fixme, + logging-format-interpolation, + logging-too-many-args, + attribute-defined-outside-init, + abstract-method, + pointless-statement, + wrong-import-order, + duplicate-code, + unbalanced-tuple-unpacking, + unused-argument + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=yes + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. 
+max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the 'python-enchant' package. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. 
+additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io \ No newline at end of file diff --git a/plugins/mamba-cp/README.md b/plugins/mamba-cp/README.md new file mode 100644 index 00000000..68ba18fa --- /dev/null +++ b/plugins/mamba-cp/README.md @@ -0,0 +1 @@ +# Context Parallel for Mamba Kernels diff --git a/plugins/mamba-cp/configs/mcp.yaml b/plugins/mamba-cp/configs/mcp.yaml new file mode 100644 index 00000000..5f7f19c4 --- /dev/null +++ b/plugins/mamba-cp/configs/mcp.yaml @@ -0,0 +1,8 @@ +training: + odm: + odm: + update_interval: 1 # update every step + sampling_interval: 1 # sample category for every sample + reward_type: entropy # type of reward to use + gamma: 0.1 # MAB hyper-parameter + eta: 0.1 # MAB hyper-parameter diff --git a/plugins/mamba-cp/pyproject.toml b/plugins/mamba-cp/pyproject.toml new file mode 100644 index 00000000..804ef0a7 --- /dev/null +++ b/plugins/mamba-cp/pyproject.toml @@ -0,0 +1,29 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "fms-acceleration-mcp" +version = '0.1.1.dev' +description = "FMS Acceleration plugin for context parallel for mamba kernels" +authors = [ + {name = "Mehant Kammakomati", email = "mehant.kammakomati2@ibm.com"}, +] +license = {text = "Apache-2.0"} +readme = "README.md" +requires-python = "~=3.11" +keywords = ['fms-hf-tuning', 'acceleration', 'mamba-cp'] +classifiers=[ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", +] + +dependencies = [] + +[tool.hatch.build.targets.wheel] +only-include = ["src/fms_acceleration_mcp"] + +[tool.hatch.build.targets.wheel.sources] +"src" = "" diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/.DS_Store b/plugins/mamba-cp/src/fms_acceleration_mcp/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..61382622f25f28f9e4986c5b19e6649b20f38e13 GIT binary patch literal 6148 zcmeHK!A`?441IwiP2#d6$NT`Q_=9R1H^gP)Jjyl(V%?^-l{og(dA7r}Lb<|(>?wJ+ zV>`){mNWpuxLMo)BLG7-!5~VHh`V>_ECmY&tx@gjr@N-A_dR7!(0aMYp`P7jaCrvTCFAX78+Yn zE9rYz|B6`^tIQAeXFt}TVapmNBfqZyKJkCFtJ9?3SLF;i1I~am@J9^r&K4P78hY;x zI0MeWfdM%mBAZ~Iur<`9gPlGBhy%J+sLNYIa$>?fVQa_{O0ZO-rH1sx2$oKNOmTU_ z*3i-sbY|8`Gk-o_g3eBV)Nq8{(0gaV8R#-_Zstht|6BfLdLQ}S6t|oKXW*YP5C)U! 
zWXw&ayY<`l)=;X5U(<>HA)thK=M4M; F10Q>8JNp0t literal 0 HcmV?d00001 diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/__init__.py b/plugins/mamba-cp/src/fms_acceleration_mcp/__init__.py new file mode 100644 index 00000000..fcb6980a --- /dev/null +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/__init__.py @@ -0,0 +1,20 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Local +from .callback import DataloaderSavingCallback +from .framework_plugin_odm import OnlineDataMixingAccelerationPlugin +from .odm import OnlineMixingDataset, Reward, compute_reward +from .patch import patch_hf_trainer_evaluate diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/callback.py b/plugins/mamba-cp/src/fms_acceleration_mcp/callback.py new file mode 100644 index 00000000..f7430f80 --- /dev/null +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/callback.py @@ -0,0 +1,38 @@ +# fms-hf-tuning patch +# Standard +from logging import getLogger +import os + +# Third Party +from transformers import TrainerCallback +import torch + +logger = getLogger(__name__) + + +class DataloaderSavingCallback(TrainerCallback): + def __init__(self, accelerator): + super().__init__() + self.accelerator = accelerator + + def on_save(self, args, state, control, **kwargs): + if not self.accelerator.is_main_process: + return + # Third Party + # pylint: disable=import-outside-toplevel + from torchdata.stateful_dataloader import StatefulDataLoader + + checkpoint_path = os.path.join( + args.output_dir, f"checkpoint-{state.global_step}" + ) + # It is assumed that one of the datasets would be stateful + # if stateful then it would be training dataset + for i, _ in enumerate(self.accelerator._dataloaders): + if isinstance( + self.accelerator._dataloaders[i].base_dataloader, StatefulDataLoader + ): + torch.save( + self.accelerator._dataloaders[i].state_dict(), + os.path.join(checkpoint_path, "odm_dl_state_dict.bin"), + ) + break diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/framework_plugin_mcp.py b/plugins/mamba-cp/src/fms_acceleration_mcp/framework_plugin_mcp.py new file mode 100644 index 00000000..a5813216 --- /dev/null +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/framework_plugin_mcp.py @@ -0,0 +1,90 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Standard +from typing import Dict, Tuple + +# Third Party +from fms_acceleration import AccelerationPlugin +from peft import LoraConfig +from transformers import TrainingArguments +import torch + +from .utils import patch_mamba_layers_with_cp_head + + +# pylint: disable=too-many-instance-attributes +class MCPAccelerationPlugin(AccelerationPlugin): + + def __init__(self, configurations: Dict[str, Dict]): + super().__init__(configurations) + self._mamba_cp_degree = self._check_config_and_maybe_check_values( + key="training.mamba.cp.degree", + default=None, + ) + self._cp_mamba_impl = self._check_config_and_maybe_check_values( + key="training.mamba.cp.mamba_impl", + default="allgather", + ) + self._cp_attn_impl = self._check_config_and_maybe_check_values( + key="training.mamba.cp.attn_impl", + default="ring", + ) + self._cp_mamba_recompute = self._check_config_and_maybe_check_values( + key="training.mamba.cp.mamba_recompute", + default=False, + ) + # data_config file should be there + @property + def requires_augmentation(self): + return True + + def augmentation( + self, + model, + train_args: TrainingArguments, + modifiable_args: Tuple[LoraConfig], + ): + if self._mamba_cp_degree != None: + rank = 0 + if torch.distributed.is_initialized(): + rank = torch.distributed.get_node_local_rank() + world_size = torch.distributed.get_world_size() + model_name = model.config.name_or_path + patch_mamba_layers_with_cp_head( + model=model, + checkpoint_name_or_path=model_name, + rank=rank, + cp_degree=self._mamba_cp_degree, + world_size=world_size, + cp_mamba_impl=self._cp_mamba_impl, + cp_attn_impl=self._cp_attn_impl, + cp_mamba_recompute=self._cp_mamba_recompute, + ) + return model, modifiable_args + + def get_callbacks_and_ready_for_train( + self, model: torch.nn.Module = None, accelerator=None + ): + callbacks = [] + return callbacks + + +# register +AccelerationPlugin.register_plugin( + MCPAccelerationPlugin, + configuration_and_paths=[ + "training.mamba.cp", + ], +) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/patch.py b/plugins/mamba-cp/src/fms_acceleration_mcp/patch.py new file mode 100644 index 00000000..a8d40e23 --- /dev/null +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/patch.py @@ -0,0 +1,233 @@ +# fms-hf-tuning patch +# Standard +from logging import getLogger +import os + +logger = getLogger(__name__) + + +def patch_hf_trainer_evaluate(): + # Third Party + # pylint: disable=import-outside-toplevel + from fms_acceleration.model_patcher import patch_target_module + from transformers import Trainer + + Trainer._evaluate = _evaluate + Trainer._get_dataloader = _get_dataloader + Trainer.get_train_dataloader = get_train_dataloader + patch_target_module("transformers.trainer.Trainer", Trainer) + patch_target_module("transformers.trainer.skip_first_batches", skip_first_batches) + + +# code taken from transformers, modified and patches original function +def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False): + # Standard + # pylint: disable=import-outside-toplevel + import time + + # Third Party + # pylint: disable=import-outside-toplevel + import torch + + metrics = None + if ( + self.model.ta_eval_steps + and self.state.global_step % self.model.ta_eval_steps == 0 + ): + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + self._report_to_hp_search(trial, self.state.global_step, metrics) + + # Run delayed LR scheduler now that metrics are populated + if ( + isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) + and not skip_scheduler + ): 
+ metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + try: + self.lr_scheduler.step(metrics[metric_to_check]) + except KeyError as exc: + raise KeyError( + f"The `metric_for_best_model` training argument is " + f"set to '{metric_to_check}', " + f"which is not found in the evaluation metrics. " + f"The available evaluation metrics are: {list(metrics.keys())}." + f"Please ensure that the `compute_metrics` function returns a " + f"dictionary that includes '{metric_to_check}' or " + f"consider changing the `metric_for_best_model` via the TrainingArguments." + ) from exc + + if self.state.global_step % self.model.ta_update_interval == 0: + # prepare model + # code taken from def evaluation_loop from HF + model = self._wrap_model(self.model, training=False) + args = self.args + if len(self.accelerator._models) == 0 and model is self.model: + start_time = time.time() + model = ( + self.accelerator.prepare(model) + if self.is_deepspeed_enabled + or ( + self.is_fsdp_enabled + and self.accelerator.mixed_precision != "fp8" + and not self.args.torch_compile + ) + else self.accelerator.prepare_model(model, evaluation_mode=True) + ) + self.model_preparation_time = round(time.time() - start_time, 4) + + if self.is_fsdp_enabled: + self.model = model + + # for the rest of this function `model` is the outside model, + # whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + + if hasattr(model, "eval") and callable(model.eval): + model.eval() + if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval): + self.optimizer.eval() + # Do this before wrapping. 
+ if args.past_index >= 0: + self._past = None + # prepare dataloader + self.train_dataset.update_sampling_weights(model, self.accelerator, self.state) + + return metrics + + +# code taken from transformers, modified and patches original function +def _get_dataloader( + self, + dataset, + description, + batch_size, + sampler_fn=None, + is_training=False, + dataloader_key=None, +): + """Create a [`~torch.utils.data.DataLoader`] from the given dataset.""" + # Standard + # pylint: disable=import-outside-toplevel + from functools import partial + + # Third Party + # pylint: disable=import-outside-toplevel + from torch.utils.data import DataLoader + from torchdata.stateful_dataloader import StatefulDataLoader + from transformers import is_datasets_available + from transformers.trainer_utils import seed_worker + import torch + + if is_datasets_available(): + # Third Party + # pylint: disable=import-outside-toplevel + import datasets + + data_collator = self.data_collator + if is_datasets_available() and isinstance(dataset, datasets.Dataset): + dataset = self._remove_unused_columns(dataset, description=description) + else: + data_collator = self._get_collator_with_removed_columns( + self.data_collator, description=description + ) + + dataloader_params = { + "batch_size": batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(dataset, torch.utils.data.IterableDataset): + if sampler_fn is not None: + dataloader_params["sampler"] = sampler_fn(dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + if is_training: + dataloader_params["worker_init_fn"] = partial( + seed_worker, + num_workers=self.args.dataloader_num_workers, + rank=self.args.process_index, + ) + if is_training: + self.accelerator.dataloader_config.use_stateful_dataloader = True + dataloader = self.accelerator.prepare( + StatefulDataLoader(dataset, **dataloader_params) + ) + else: + dataloader = self.accelerator.prepare(DataLoader(dataset, **dataloader_params)) + + # Store the prepared dataloader for subsequent evaluations if using persistent workers. 
+ if dataloader_key is not None and self.args.dataloader_persistent_workers: + if hasattr(self, "_eval_dataloaders"): + self._eval_dataloaders[dataloader_key] = dataloader + else: + self._eval_dataloaders = {dataloader_key: dataloader} + + return dataloader + + +# code taken from transformers, modified and patches original function +def get_train_dataloader(self): + # Third Party + # pylint: disable=import-outside-toplevel + from torchdata.stateful_dataloader import StatefulDataLoader + from transformers.trainer_utils import get_last_checkpoint + import torch + + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + dataloader = self._get_dataloader( + dataset=self.train_dataset, + description="Training", + batch_size=self._train_batch_size, + sampler_fn=self._get_train_sampler, + is_training=True, + ) + resume_from_checkpoint = self.model.resume_from_checkpoint + if resume_from_checkpoint: + # code taken from transformers and modified + if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: + resume_from_checkpoint = get_last_checkpoint(self.args.output_dir) + if resume_from_checkpoint is None: + raise ValueError( + f"No valid checkpoint found in output directory ({self.args.output_dir})" + ) + self.model.resume_from_checkpoint = resume_from_checkpoint + + # load state to the dataloader + dataloader_state_dict_name = "odm_dl_state_dict.bin" + output_dataloader_state_dict_file = os.path.join( + resume_from_checkpoint, dataloader_state_dict_name + ) + for i, _ in enumerate(self.accelerator._dataloaders): + if isinstance( + self.accelerator._dataloaders[i].base_dataloader, StatefulDataLoader + ): + self.accelerator._dataloaders[i].load_state_dict( + torch.load(output_dataloader_state_dict_file) + ) + break + return dataloader + + +# code taken from transformers, modified and patches original function +def skip_first_batches(dataloader, num_batches=0): + return dataloader diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/__init__.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/__init__.py new file mode 100644 index 00000000..c2caf6e0 --- /dev/null +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/__init__.py @@ -0,0 +1,17 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Local +from .utils import patch_mamba_layers_with_cp_head + diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py new file mode 100644 index 00000000..c660c112 --- /dev/null +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -0,0 +1,89 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict +# pylint: disable=import-error +from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh +import torch +from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0 + +key_ep = "cp" +key_rep = "dp_shard" + +def hf_config_ssm_config(hf_config) -> Dict: + config_ssm = {} + config_ssm["d_model"] = hf_config.hidden_size + config_ssm["n_layer"] = hf_config.num_hidden_layers + config_ssm["tie_embeddings"] = hf_config.tie_word_embeddings + config_ssm["d_state"] = 128 + config_ssm["ngroups"] = hf_config.mamba_n_groups + config_ssm["rmsnorm"] = True + config_ssm["chunk_size"] = hf_config.mamba_chunk_size + config_ssm["conv_bias"] = hf_config.mamba_conv_bias + config_ssm["d_conv"] = hf_config.mamba_d_conv + return config_ssm + + +def patch_mamba_layers_with_cp_head( + model, + checkpoint_name_or_path, + rank, + cp_degree, + world_size, + cp_mamba_impl, + cp_attn_impl, + cp_mamba_recompute +): + config_ssm = hf_config_ssm_config(model.config) + device = torch.device(f"cuda:{rank}") + if is_fsdp_enabled(): + device = torch.device("cpu") + try: + from mamba_ssm.modules.mamba2_cp import Mamba2CP + except ImportError: + ValueError( + "Mamba2CP is required to enable context parallelism for mamba layers" + ) + rep_size = world_size // cp_degree + + if cp_degree == 1: + raise ValueError("CP degree can't be one") + elif rep_size == 1: + device_mesh = init_device_mesh( + "cuda", + (cp_degree,), + mesh_dim_names=(key_ep,), + ) + else: + device_mesh = init_device_mesh( + "cuda", + (rep_size, cp_degree), + mesh_dim_names=(key_rep, key_ep), + ) + + cp_args = { + "cp_mesh": device_mesh[key_ep], + "cp_mamba_impl": cp_mamba_impl, + "cp_attn_impl": cp_attn_impl, + "cp_mamba_recompute": cp_mamba_recompute, + } + + with torch.no_grad(): + for layer in model.layers: + mamba_layer = Mamba2CP(**config_ssm, **cp_args) + mamba_layer.load_state_dict(layer.mamba.state_dict()) + setattr(layer, "mamba", mamba_layer) + layer.to(device) + + if hasattr(model, "tie_weights"): + model.tie_weights() diff --git a/plugins/mamba-cp/tox.ini b/plugins/mamba-cp/tox.ini new file mode 100644 index 00000000..1a21a899 --- /dev/null +++ b/plugins/mamba-cp/tox.ini @@ -0,0 +1,50 @@ +[tox] +envlist = py, lint + +[testenv] +deps = + pytest>=7 + importlib-metadata + -e {toxinidir} +skip_install = true +commands = + + # install the dependencies here to ensure + # the order + pip install -e {toxinidir}/../framework + pytest {posargs:tests} + +[testenv:lint] +description = run linters +skip_install = false +deps = + -e {toxinidir}/../framework + pylint>=2.16.2,<=3.1.0 + datasets() +commands = + pylint src tests +allowlist_externals = pylint + +[testenv:fmt] +description = format +skip_install = true +deps = + black>=22.12 + isort>=5.11 +commands = + black {posargs:.} + isort {posargs:.} + +[testenv:build] +description = build wheel +deps = + build +commands = python -m build -w +skip_install = True + +[testenv:twinecheck] +description = check wheel +deps = + twine +commands = twine check dist/* +skip_install = True \ No newline at end of file From 
f479ef5b594e555edbb082ee65a6d4534e1f94b2 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 14 Nov 2025 00:00:07 +0530 Subject: [PATCH 02/27] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- .../src/fms_acceleration_mcp/__init__.py | 5 +- .../src/fms_acceleration_mcp/callback.py | 38 --- .../src/fms_acceleration_mcp/patch.py | 233 ------------------ 3 files changed, 1 insertion(+), 275 deletions(-) delete mode 100644 plugins/mamba-cp/src/fms_acceleration_mcp/callback.py delete mode 100644 plugins/mamba-cp/src/fms_acceleration_mcp/patch.py diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/__init__.py b/plugins/mamba-cp/src/fms_acceleration_mcp/__init__.py index fcb6980a..1e80026a 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/__init__.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/__init__.py @@ -14,7 +14,4 @@ # Local -from .callback import DataloaderSavingCallback -from .framework_plugin_odm import OnlineDataMixingAccelerationPlugin -from .odm import OnlineMixingDataset, Reward, compute_reward -from .patch import patch_hf_trainer_evaluate +from .framework_plugin_mcp import MCPAccelerationPlugin diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/callback.py b/plugins/mamba-cp/src/fms_acceleration_mcp/callback.py deleted file mode 100644 index f7430f80..00000000 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/callback.py +++ /dev/null @@ -1,38 +0,0 @@ -# fms-hf-tuning patch -# Standard -from logging import getLogger -import os - -# Third Party -from transformers import TrainerCallback -import torch - -logger = getLogger(__name__) - - -class DataloaderSavingCallback(TrainerCallback): - def __init__(self, accelerator): - super().__init__() - self.accelerator = accelerator - - def on_save(self, args, state, control, **kwargs): - if not self.accelerator.is_main_process: - return - # Third Party - # pylint: disable=import-outside-toplevel - from torchdata.stateful_dataloader import StatefulDataLoader - - checkpoint_path = os.path.join( - args.output_dir, f"checkpoint-{state.global_step}" - ) - # It is assumed that one of the datasets would be stateful - # if stateful then it would be training dataset - for i, _ in enumerate(self.accelerator._dataloaders): - if isinstance( - self.accelerator._dataloaders[i].base_dataloader, StatefulDataLoader - ): - torch.save( - self.accelerator._dataloaders[i].state_dict(), - os.path.join(checkpoint_path, "odm_dl_state_dict.bin"), - ) - break diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/patch.py b/plugins/mamba-cp/src/fms_acceleration_mcp/patch.py deleted file mode 100644 index a8d40e23..00000000 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/patch.py +++ /dev/null @@ -1,233 +0,0 @@ -# fms-hf-tuning patch -# Standard -from logging import getLogger -import os - -logger = getLogger(__name__) - - -def patch_hf_trainer_evaluate(): - # Third Party - # pylint: disable=import-outside-toplevel - from fms_acceleration.model_patcher import patch_target_module - from transformers import Trainer - - Trainer._evaluate = _evaluate - Trainer._get_dataloader = _get_dataloader - Trainer.get_train_dataloader = get_train_dataloader - patch_target_module("transformers.trainer.Trainer", Trainer) - patch_target_module("transformers.trainer.skip_first_batches", skip_first_batches) - - -# code taken from transformers, modified and patches original function -def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False): - # Standard - # pylint: disable=import-outside-toplevel - import time - - # Third Party - # 
pylint: disable=import-outside-toplevel - import torch - - metrics = None - if ( - self.model.ta_eval_steps - and self.state.global_step % self.model.ta_eval_steps == 0 - ): - metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) - self._report_to_hp_search(trial, self.state.global_step, metrics) - - # Run delayed LR scheduler now that metrics are populated - if ( - isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) - and not skip_scheduler - ): - metric_to_check = self.args.metric_for_best_model - if not metric_to_check.startswith("eval_"): - metric_to_check = f"eval_{metric_to_check}" - try: - self.lr_scheduler.step(metrics[metric_to_check]) - except KeyError as exc: - raise KeyError( - f"The `metric_for_best_model` training argument is " - f"set to '{metric_to_check}', " - f"which is not found in the evaluation metrics. " - f"The available evaluation metrics are: {list(metrics.keys())}." - f"Please ensure that the `compute_metrics` function returns a " - f"dictionary that includes '{metric_to_check}' or " - f"consider changing the `metric_for_best_model` via the TrainingArguments." - ) from exc - - if self.state.global_step % self.model.ta_update_interval == 0: - # prepare model - # code taken from def evaluation_loop from HF - model = self._wrap_model(self.model, training=False) - args = self.args - if len(self.accelerator._models) == 0 and model is self.model: - start_time = time.time() - model = ( - self.accelerator.prepare(model) - if self.is_deepspeed_enabled - or ( - self.is_fsdp_enabled - and self.accelerator.mixed_precision != "fp8" - and not self.args.torch_compile - ) - else self.accelerator.prepare_model(model, evaluation_mode=True) - ) - self.model_preparation_time = round(time.time() - start_time, 4) - - if self.is_fsdp_enabled: - self.model = model - - # for the rest of this function `model` is the outside model, - # whether it was wrapped or not - if model is not self.model: - self.model_wrapped = model - - # backward compatibility - if self.is_deepspeed_enabled: - self.deepspeed = self.model_wrapped - - # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called - # while ``train`` is running, cast it to the right dtype first and then put on device - if not self.is_in_train: - if args.fp16_full_eval: - model = model.to(dtype=torch.float16, device=args.device) - elif args.bf16_full_eval: - model = model.to(dtype=torch.bfloat16, device=args.device) - - if hasattr(model, "eval") and callable(model.eval): - model.eval() - if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval): - self.optimizer.eval() - # Do this before wrapping. 
- if args.past_index >= 0: - self._past = None - # prepare dataloader - self.train_dataset.update_sampling_weights(model, self.accelerator, self.state) - - return metrics - - -# code taken from transformers, modified and patches original function -def _get_dataloader( - self, - dataset, - description, - batch_size, - sampler_fn=None, - is_training=False, - dataloader_key=None, -): - """Create a [`~torch.utils.data.DataLoader`] from the given dataset.""" - # Standard - # pylint: disable=import-outside-toplevel - from functools import partial - - # Third Party - # pylint: disable=import-outside-toplevel - from torch.utils.data import DataLoader - from torchdata.stateful_dataloader import StatefulDataLoader - from transformers import is_datasets_available - from transformers.trainer_utils import seed_worker - import torch - - if is_datasets_available(): - # Third Party - # pylint: disable=import-outside-toplevel - import datasets - - data_collator = self.data_collator - if is_datasets_available() and isinstance(dataset, datasets.Dataset): - dataset = self._remove_unused_columns(dataset, description=description) - else: - data_collator = self._get_collator_with_removed_columns( - self.data_collator, description=description - ) - - dataloader_params = { - "batch_size": batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - } - - if not isinstance(dataset, torch.utils.data.IterableDataset): - if sampler_fn is not None: - dataloader_params["sampler"] = sampler_fn(dataset) - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - if is_training: - dataloader_params["worker_init_fn"] = partial( - seed_worker, - num_workers=self.args.dataloader_num_workers, - rank=self.args.process_index, - ) - if is_training: - self.accelerator.dataloader_config.use_stateful_dataloader = True - dataloader = self.accelerator.prepare( - StatefulDataLoader(dataset, **dataloader_params) - ) - else: - dataloader = self.accelerator.prepare(DataLoader(dataset, **dataloader_params)) - - # Store the prepared dataloader for subsequent evaluations if using persistent workers. 
- if dataloader_key is not None and self.args.dataloader_persistent_workers: - if hasattr(self, "_eval_dataloaders"): - self._eval_dataloaders[dataloader_key] = dataloader - else: - self._eval_dataloaders = {dataloader_key: dataloader} - - return dataloader - - -# code taken from transformers, modified and patches original function -def get_train_dataloader(self): - # Third Party - # pylint: disable=import-outside-toplevel - from torchdata.stateful_dataloader import StatefulDataLoader - from transformers.trainer_utils import get_last_checkpoint - import torch - - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - - dataloader = self._get_dataloader( - dataset=self.train_dataset, - description="Training", - batch_size=self._train_batch_size, - sampler_fn=self._get_train_sampler, - is_training=True, - ) - resume_from_checkpoint = self.model.resume_from_checkpoint - if resume_from_checkpoint: - # code taken from transformers and modified - if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: - resume_from_checkpoint = get_last_checkpoint(self.args.output_dir) - if resume_from_checkpoint is None: - raise ValueError( - f"No valid checkpoint found in output directory ({self.args.output_dir})" - ) - self.model.resume_from_checkpoint = resume_from_checkpoint - - # load state to the dataloader - dataloader_state_dict_name = "odm_dl_state_dict.bin" - output_dataloader_state_dict_file = os.path.join( - resume_from_checkpoint, dataloader_state_dict_name - ) - for i, _ in enumerate(self.accelerator._dataloaders): - if isinstance( - self.accelerator._dataloaders[i].base_dataloader, StatefulDataLoader - ): - self.accelerator._dataloaders[i].load_state_dict( - torch.load(output_dataloader_state_dict_file) - ) - break - return dataloader - - -# code taken from transformers, modified and patches original function -def skip_first_batches(dataloader, num_batches=0): - return dataloader From b1e7e21c94f25803e1b2f0f6e88f6d1b4dba2bfc Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 14 Nov 2025 17:17:54 +0530 Subject: [PATCH 03/27] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index c660c112..40324d78 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -16,6 +16,7 @@ from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh import torch from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0 +from tqdm import tqdm key_ep = "cp" key_rep = "dp_shard" @@ -79,7 +80,7 @@ def patch_mamba_layers_with_cp_head( } with torch.no_grad(): - for layer in model.layers: + for layer in tqdm(model.layers, desc="Swapping mamba layers", total=len(model.layers)): mamba_layer = Mamba2CP(**config_ssm, **cp_args) mamba_layer.load_state_dict(layer.mamba.state_dict()) setattr(layer, "mamba", mamba_layer) From 7185e25eda3318497a372b5c3a94c7b3620af704 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 14 Nov 2025 17:35:57 +0530 Subject: [PATCH 04/27] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index 40324d78..5bc0faf3 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -80,7 +80,7 @@ def patch_mamba_layers_with_cp_head( } with torch.no_grad(): - for layer in tqdm(model.layers, desc="Swapping mamba layers", total=len(model.layers)): + for layer in tqdm(model.model.layers, desc="Swapping mamba layers"): mamba_layer = Mamba2CP(**config_ssm, **cp_args) mamba_layer.load_state_dict(layer.mamba.state_dict()) setattr(layer, "mamba", mamba_layer) From 6ea7165060ee5ac8b7f10d11a3d46fba4eed41d2 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 14 Nov 2025 17:38:22 +0530 Subject: [PATCH 05/27] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index 5bc0faf3..be92c087 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -24,7 +24,6 @@ def hf_config_ssm_config(hf_config) -> Dict: config_ssm = {} config_ssm["d_model"] = hf_config.hidden_size - config_ssm["n_layer"] = hf_config.num_hidden_layers config_ssm["tie_embeddings"] = hf_config.tie_word_embeddings config_ssm["d_state"] = 128 config_ssm["ngroups"] = hf_config.mamba_n_groups From 0eb1bfd5e37e22902ba241134679442186ec91f4 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 14 Nov 2025 17:40:00 +0530 Subject: [PATCH 06/27] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index be92c087..f9c392f1 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -24,7 +24,6 @@ def hf_config_ssm_config(hf_config) -> Dict: config_ssm = {} config_ssm["d_model"] = hf_config.hidden_size - config_ssm["tie_embeddings"] = hf_config.tie_word_embeddings config_ssm["d_state"] = 128 config_ssm["ngroups"] = hf_config.mamba_n_groups config_ssm["rmsnorm"] = True From 6f2aaa4a4637af2b2b273919e0c299c78d94a6db Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 14 Nov 2025 17:54:44 +0530 Subject: [PATCH 07/27] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- .../src/fms_acceleration_mcp/framework_plugin_mcp.py | 5 ----- plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py | 2 -- 2 files changed, 7 deletions(-) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/framework_plugin_mcp.py b/plugins/mamba-cp/src/fms_acceleration_mcp/framework_plugin_mcp.py index a5813216..95decd76 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/framework_plugin_mcp.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/framework_plugin_mcp.py @@ -37,10 +37,6 @@ def __init__(self, configurations: Dict[str, Dict]): key="training.mamba.cp.mamba_impl", default="allgather", ) - self._cp_attn_impl = self._check_config_and_maybe_check_values( - key="training.mamba.cp.attn_impl", - default="ring", - ) self._cp_mamba_recompute = self._check_config_and_maybe_check_values( key="training.mamba.cp.mamba_recompute", 
default=False, @@ -69,7 +65,6 @@ def augmentation( cp_degree=self._mamba_cp_degree, world_size=world_size, cp_mamba_impl=self._cp_mamba_impl, - cp_attn_impl=self._cp_attn_impl, cp_mamba_recompute=self._cp_mamba_recompute, ) return model, modifiable_args diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index f9c392f1..27dc963e 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -40,7 +40,6 @@ def patch_mamba_layers_with_cp_head( cp_degree, world_size, cp_mamba_impl, - cp_attn_impl, cp_mamba_recompute ): config_ssm = hf_config_ssm_config(model.config) @@ -73,7 +72,6 @@ def patch_mamba_layers_with_cp_head( cp_args = { "cp_mesh": device_mesh[key_ep], "cp_mamba_impl": cp_mamba_impl, - "cp_attn_impl": cp_attn_impl, "cp_mamba_recompute": cp_mamba_recompute, } From a8832be0288bf3e718925590342ebd4640a99179 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 14 Nov 2025 17:59:55 +0530 Subject: [PATCH 08/27] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- .../mamba-cp/src/fms_acceleration_mcp/utils/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index 27dc963e..7ec10944 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -77,10 +77,12 @@ def patch_mamba_layers_with_cp_head( with torch.no_grad(): for layer in tqdm(model.model.layers, desc="Swapping mamba layers"): - mamba_layer = Mamba2CP(**config_ssm, **cp_args) - mamba_layer.load_state_dict(layer.mamba.state_dict()) - setattr(layer, "mamba", mamba_layer) - layer.to(device) + if hasattr(layer, "mamba") and layer.mamba is not None: + print("mamba layer found") + mamba_layer = Mamba2CP(**config_ssm, **cp_args) + mamba_layer.load_state_dict(layer.mamba.state_dict()) + setattr(layer, "mamba", mamba_layer) + layer.to(device) if hasattr(model, "tie_weights"): model.tie_weights() From 893f4cd6229d8df05fdb8d97aba2464e9bd9fd55 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 14 Nov 2025 18:05:27 +0530 Subject: [PATCH 09/27] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index 7ec10944..b5f78478 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -80,9 +80,11 @@ def patch_mamba_layers_with_cp_head( if hasattr(layer, "mamba") and layer.mamba is not None: print("mamba layer found") mamba_layer = Mamba2CP(**config_ssm, **cp_args) + dtype = layer.mamba.dtype + device = layer.mamba.device mamba_layer.load_state_dict(layer.mamba.state_dict()) setattr(layer, "mamba", mamba_layer) - layer.to(device) + layer.to(dtype).to(device) if hasattr(model, "tie_weights"): model.tie_weights() From 1e9433b180780e997ddefdd2869fe62e0cc0d56b Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Mon, 17 Nov 2025 13:08:12 +0900 Subject: [PATCH 10/27] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index b5f78478..23535eb5 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -76,11 +76,11 @@ def patch_mamba_layers_with_cp_head( } with torch.no_grad(): + dtype = model.dtype for layer in tqdm(model.model.layers, desc="Swapping mamba layers"): if hasattr(layer, "mamba") and layer.mamba is not None: print("mamba layer found") mamba_layer = Mamba2CP(**config_ssm, **cp_args) - dtype = layer.mamba.dtype device = layer.mamba.device mamba_layer.load_state_dict(layer.mamba.state_dict()) setattr(layer, "mamba", mamba_layer) From ec60dabdb91e6bf5f650225744aa1e7fa4865f36 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Mon, 17 Nov 2025 13:11:23 +0900 Subject: [PATCH 11/27] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index 23535eb5..30598be5 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -77,11 +77,11 @@ def patch_mamba_layers_with_cp_head( with torch.no_grad(): dtype = model.dtype + device = model.device for layer in tqdm(model.model.layers, desc="Swapping mamba layers"): if hasattr(layer, "mamba") and layer.mamba is not None: print("mamba layer found") mamba_layer = Mamba2CP(**config_ssm, **cp_args) - device = layer.mamba.device mamba_layer.load_state_dict(layer.mamba.state_dict()) setattr(layer, "mamba", mamba_layer) layer.to(dtype).to(device) From ffe2f175db6c269e74a4d24bac7cbb7507f29f8a Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Wed, 19 Nov 2025 17:17:25 +0900 Subject: [PATCH 12/27] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- .../src/fms_acceleration_mcp/utils/utils.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index 30598be5..d9d31ba3 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -12,15 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from typing import Dict + # pylint: disable=import-error from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh import torch from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0 from tqdm import tqdm +try: + from mamba_ssm.modules.mamba2_cp import Mamba2CP +except ImportError: + ValueError("Mamba2CP is required to enable context parallelism for mamba layers") + key_ep = "cp" key_rep = "dp_shard" + def hf_config_ssm_config(hf_config) -> Dict: config_ssm = {} config_ssm["d_model"] = hf_config.hidden_size @@ -33,6 +40,12 @@ def hf_config_ssm_config(hf_config) -> Dict: return config_ssm +class Mamba2CPHF(Mamba2CP): + def forward( + self, hidden_states, cache_params, cache_position, attention_mask, seq_idx + ): + return super().forward(u=hidden_states, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None) + def patch_mamba_layers_with_cp_head( model, checkpoint_name_or_path, @@ -40,18 +53,12 @@ def patch_mamba_layers_with_cp_head( cp_degree, world_size, cp_mamba_impl, - cp_mamba_recompute + cp_mamba_recompute, ): config_ssm = hf_config_ssm_config(model.config) device = torch.device(f"cuda:{rank}") if is_fsdp_enabled(): device = torch.device("cpu") - try: - from mamba_ssm.modules.mamba2_cp import Mamba2CP - except ImportError: - ValueError( - "Mamba2CP is required to enable context parallelism for mamba layers" - ) rep_size = world_size // cp_degree if cp_degree == 1: @@ -74,14 +81,14 @@ def patch_mamba_layers_with_cp_head( "cp_mamba_impl": cp_mamba_impl, "cp_mamba_recompute": cp_mamba_recompute, } - + with torch.no_grad(): dtype = model.dtype device = model.device for layer in tqdm(model.model.layers, desc="Swapping mamba layers"): if hasattr(layer, "mamba") and layer.mamba is not None: print("mamba layer found") - mamba_layer = Mamba2CP(**config_ssm, **cp_args) + mamba_layer = Mamba2CPHF(**config_ssm, **cp_args) mamba_layer.load_state_dict(layer.mamba.state_dict()) setattr(layer, "mamba", mamba_layer) layer.to(dtype).to(device) From a7365ea342bdbfff4ad73e3642e53143d321fb9b Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Wed, 19 Nov 2025 17:38:00 +0900 Subject: [PATCH 13/27] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index d9d31ba3..14ef8015 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -42,7 +42,7 @@ def hf_config_ssm_config(hf_config) -> Dict: class Mamba2CPHF(Mamba2CP): def forward( - self, hidden_states, cache_params, cache_position, attention_mask, seq_idx + self, hidden_states, cache_params, cache_position, attention_mask, seq_idx, **kwargs ): return super().forward(u=hidden_states, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None) From e9fcfa7be15705f5b19f0133aef0868fa297e4a1 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Wed, 19 Nov 2025 17:41:09 +0900 Subject: [PATCH 14/27] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index 14ef8015..f0579fb3 100644 --- 
a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -42,7 +42,7 @@ def hf_config_ssm_config(hf_config) -> Dict: class Mamba2CPHF(Mamba2CP): def forward( - self, hidden_states, cache_params, cache_position, attention_mask, seq_idx, **kwargs + self, hidden_states, cache_params=None, cache_position=None, attention_mask=None, seq_idx=None, **kwargs ): return super().forward(u=hidden_states, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None) From 3fc0a349fe5992eaa35ce2c10902fdb7fcef9f79 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Mon, 24 Nov 2025 12:47:13 +0530 Subject: [PATCH 15/27] feat: add support for mamba cp Signed-off-by: Mehant Kammakomati --- .../framework_plugin_mcp.py | 4 ++- .../fms_acceleration_mcp/utils/__init__.py | 1 - .../src/fms_acceleration_mcp/utils/utils.py | 34 +++++++++++++------ 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/framework_plugin_mcp.py b/plugins/mamba-cp/src/fms_acceleration_mcp/framework_plugin_mcp.py index 95decd76..9850cb51 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/framework_plugin_mcp.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/framework_plugin_mcp.py @@ -21,6 +21,7 @@ from transformers import TrainingArguments import torch +# Local from .utils import patch_mamba_layers_with_cp_head @@ -41,6 +42,7 @@ def __init__(self, configurations: Dict[str, Dict]): key="training.mamba.cp.mamba_recompute", default=False, ) + # data_config file should be there @property def requires_augmentation(self): @@ -52,7 +54,7 @@ def augmentation( train_args: TrainingArguments, modifiable_args: Tuple[LoraConfig], ): - if self._mamba_cp_degree != None: + if self._mamba_cp_degree is not None: rank = 0 if torch.distributed.is_initialized(): rank = torch.distributed.get_node_local_rank() diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/__init__.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/__init__.py index c2caf6e0..76d835af 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/__init__.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/__init__.py @@ -14,4 +14,3 @@ # Local from .utils import patch_mamba_layers_with_cp_head - diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index f0579fb3..a5dca515 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -11,18 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+# Standard from typing import Dict +# Third Party +from mamba_ssm.modules.mamba2_cp import Mamba2CP + # pylint: disable=import-error -from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh -import torch -from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0 +from torch.distributed._tensor.device_mesh import init_device_mesh from tqdm import tqdm - -try: - from mamba_ssm.modules.mamba2_cp import Mamba2CP -except ImportError: - ValueError("Mamba2CP is required to enable context parallelism for mamba layers") +from transformers.modeling_utils import is_fsdp_enabled +import torch key_ep = "cp" key_rep = "dp_shard" @@ -42,9 +41,22 @@ def hf_config_ssm_config(hf_config) -> Dict: class Mamba2CPHF(Mamba2CP): def forward( - self, hidden_states, cache_params=None, cache_position=None, attention_mask=None, seq_idx=None, **kwargs + self, + hidden_states, + cache_params=None, + cache_position=None, + attention_mask=None, + seq_idx=None, + **kwargs, ): - return super().forward(u=hidden_states, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None) + return super().forward( + u=hidden_states, + seqlen=None, + seq_idx=None, + cu_seqlens=None, + inference_params=None, + ) + def patch_mamba_layers_with_cp_head( model, @@ -63,7 +75,7 @@ def patch_mamba_layers_with_cp_head( if cp_degree == 1: raise ValueError("CP degree can't be one") - elif rep_size == 1: + if rep_size == 1: device_mesh = init_device_mesh( "cuda", (cp_degree,), From 66af3f6217f541f476d47dcbbb257d13e8fbb85b Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Wed, 26 Nov 2025 23:55:50 +0530 Subject: [PATCH 16/27] debug Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index a5dca515..75633dd7 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -67,6 +67,11 @@ def patch_mamba_layers_with_cp_head( cp_mamba_impl, cp_mamba_recompute, ): + # to avoid rechunking/sharding of the buffers + # ideally this is not optimal + from torch.distributed.tensor.experimental._attention import _cp_options + _cp_options.enable_load_balance = False + config_ssm = hf_config_ssm_config(model.config) device = torch.device(f"cuda:{rank}") if is_fsdp_enabled(): From ffea60fb2489151f4cc742531e4fa7cb1bbde4ed Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Thu, 27 Nov 2025 00:05:57 +0530 Subject: [PATCH 17/27] debug Signed-off-by: Mehant Kammakomati --- .../src/fms_acceleration_mcp/utils/utils.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index 75633dd7..2b5f1ad5 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -23,7 +23,13 @@ from transformers.modeling_utils import is_fsdp_enabled import torch -key_ep = "cp" +# to avoid rechunking/sharding of the buffers +# ideally this is not optimal +from torch.distributed.tensor.experimental._attention import _cp_options +_cp_options.enable_load_balance = False + + +key_cp = "cp" key_rep = "dp_shard" @@ -67,10 +73,6 @@ def patch_mamba_layers_with_cp_head( cp_mamba_impl, cp_mamba_recompute, ): - # to avoid rechunking/sharding of 
the buffers - # ideally this is not optimal - from torch.distributed.tensor.experimental._attention import _cp_options - _cp_options.enable_load_balance = False config_ssm = hf_config_ssm_config(model.config) device = torch.device(f"cuda:{rank}") @@ -84,17 +86,17 @@ def patch_mamba_layers_with_cp_head( device_mesh = init_device_mesh( "cuda", (cp_degree,), - mesh_dim_names=(key_ep,), + mesh_dim_names=(key_cp,), ) else: device_mesh = init_device_mesh( "cuda", (rep_size, cp_degree), - mesh_dim_names=(key_rep, key_ep), + mesh_dim_names=(key_rep, key_cp), ) cp_args = { - "cp_mesh": device_mesh[key_ep], + "cp_mesh": device_mesh[key_cp], "cp_mamba_impl": cp_mamba_impl, "cp_mamba_recompute": cp_mamba_recompute, } From bfc90a99471b22cc998e533285fb2b1a78ed8e7c Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Thu, 27 Nov 2025 18:09:13 +0530 Subject: [PATCH 18/27] fix: remove print stmts Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index 2b5f1ad5..76935ba6 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -106,7 +106,6 @@ def patch_mamba_layers_with_cp_head( device = model.device for layer in tqdm(model.model.layers, desc="Swapping mamba layers"): if hasattr(layer, "mamba") and layer.mamba is not None: - print("mamba layer found") mamba_layer = Mamba2CPHF(**config_ssm, **cp_args) mamba_layer.load_state_dict(layer.mamba.state_dict()) setattr(layer, "mamba", mamba_layer) From e4e470e9c5aceb8d1ee9eaa1a8089dcd2e3610d2 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 28 Nov 2025 11:17:15 +0530 Subject: [PATCH 19/27] docs: add docs Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/README.md | 15 ++++++++ .../src/fms_acceleration_mcp/utils/utils.py | 34 +++++++++++++++---- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/plugins/mamba-cp/README.md b/plugins/mamba-cp/README.md index 68ba18fa..55a09955 100644 --- a/plugins/mamba-cp/README.md +++ b/plugins/mamba-cp/README.md @@ -1 +1,16 @@ # Context Parallel for Mamba Kernels + +This library contains plugin for applying context parallelism for mamba module (mamba_ssm). + +## Plugins + +Plugin | Description | Depends | Loading | Augmentation | Callbacks +--|--|--|--|--|-- +[mcp](./src/fms_acceleration_mcp/framework_plugin_mcp.py) | context parallel for mamba | [custom mamba cp implementation](https://github.com/garrett361/mamba/tree/mamba-cp) | ✅ | ✅ | ✅ + +## Mamba CP Implementation + +Context parallel implementation is taken from a custom [mamba_ssm repo](https://github.com/garrett361/mamba/tree/mamba-cp) with cp implemenation. Thus, its required this repo is installed to use this plugin. + +## Known Issues +1. load balancing is removed given limited support on mamba cp implementation. This could lead to potential throughput drops for trainings using causal mask. 
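For orientation alongside the README added above, here is a minimal usage sketch of the configuration the plugin consumes, assuming the `training.mamba.cp.*` keys defined in `configs/mcp.yaml` and read by `framework_plugin_mcp.py` (`degree`, `mamba_impl`, `mamba_recompute`); the output file name and the YAML round-trip are illustrative, not part of the plugin.

```python
# Minimal sketch, assuming the training.mamba.cp.* keys used by the mcp plugin;
# the file name below is illustrative.
import yaml

mcp_config = {
    "training": {
        "mamba": {
            "cp": {
                "degree": 2,                # context-parallel degree; must be greater than 1
                "mamba_impl": "allgather",  # plugin default for the mamba CP implementation
                "mamba_recompute": False,   # plugin default; recompute the mamba op under CP
            }
        }
    }
}

# Persist the tree so it can be consumed through the framework's configuration
# loader (see the unit test added later in this series).
with open("mcp.yaml", "w", encoding="utf-8") as f:
    yaml.safe_dump(mcp_config, f)
```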
diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index 76935ba6..092f470d 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -15,7 +15,13 @@ from typing import Dict # Third Party -from mamba_ssm.modules.mamba2_cp import Mamba2CP +try: + from mamba_ssm.modules.mamba2_cp import Mamba2CP +except ImportError: + raise ValueError("custom mamba_ssm package installation is needed" + "install from https://github.com/garrett361/mamba/tree/mamba-cp" + ) +from accelerate.logging import get_logger # pylint: disable=import-error from torch.distributed._tensor.device_mesh import init_device_mesh @@ -25,15 +31,20 @@ # to avoid rechunking/sharding of the buffers # ideally this is not optimal +# this is done to make self attention cp compatible with mamba cp from torch.distributed.tensor.experimental._attention import _cp_options _cp_options.enable_load_balance = False +logger = get_logger(__name__) +# the same keys are used in accelerate +# therefore we choose these to be in sync and cross leverage. key_cp = "cp" key_rep = "dp_shard" - -def hf_config_ssm_config(hf_config) -> Dict: +# extract ssm config from hf config to be used +# while swapping the mamba modules +def get_ssmconfig_from_hfconfig(hf_config) -> Dict: config_ssm = {} config_ssm["d_model"] = hf_config.hidden_size config_ssm["d_state"] = 128 @@ -45,6 +56,7 @@ def hf_config_ssm_config(hf_config) -> Dict: return config_ssm +# to patch input arguments between mamba cp module and standard hf mamba module class Mamba2CPHF(Mamba2CP): def forward( self, @@ -63,7 +75,10 @@ def forward( inference_params=None, ) - +# patches each mamba module with mamba cp module +# mamba cp module's weights are exactly same as hf mamba module +# so we reuse the state dict and the same does not need special handling +# while checkpointing. def patch_mamba_layers_with_cp_head( model, checkpoint_name_or_path, @@ -74,12 +89,18 @@ def patch_mamba_layers_with_cp_head( cp_mamba_recompute, ): - config_ssm = hf_config_ssm_config(model.config) + config_ssm = get_ssmconfig_from_hfconfig(model.config) device = torch.device(f"cuda:{rank}") if is_fsdp_enabled(): device = torch.device("cpu") rep_size = world_size // cp_degree - + + # auto infer ddp and cp ranks + # does not work on other combination of parallelisms + logger.warning( + "Mamba CP is only meant for parallelism combinations having DP and CP" + "other combinations can lead to unexpected behaviour" + ) if cp_degree == 1: raise ValueError("CP degree can't be one") if rep_size == 1: @@ -100,7 +121,6 @@ def patch_mamba_layers_with_cp_head( "cp_mamba_impl": cp_mamba_impl, "cp_mamba_recompute": cp_mamba_recompute, } - with torch.no_grad(): dtype = model.dtype device = model.device From ec667127bc3b578052af7923eb4e8f6cf621cc6c Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 28 Nov 2025 11:34:34 +0530 Subject: [PATCH 20/27] docs: add docs Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/mamba-cp/README.md b/plugins/mamba-cp/README.md index 55a09955..1a0c2127 100644 --- a/plugins/mamba-cp/README.md +++ b/plugins/mamba-cp/README.md @@ -14,3 +14,4 @@ Context parallel implementation is taken from a custom [mamba_ssm repo](https:// ## Known Issues 1. load balancing is removed given limited support on mamba cp implementation. 
This could lead to potential throughput drops for trainings using causal mask. +2. Padding free and flash attention are not supported. \ No newline at end of file From 74202e120d7d89f344b5a6a24ff62a9aace8ad5d Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 28 Nov 2025 11:53:38 +0530 Subject: [PATCH 21/27] nit: lint and fmt Signed-off-by: Mehant Kammakomati --- .github/workflows/build-and-publish.yml | 1 + .github/workflows/format.yml | 1 + .../src/fms_acceleration_mcp/utils/utils.py | 25 +++++++++++-------- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/.github/workflows/build-and-publish.yml b/.github/workflows/build-and-publish.yml index c9080c55..af728aa2 100644 --- a/.github/workflows/build-and-publish.yml +++ b/.github/workflows/build-and-publish.yml @@ -17,6 +17,7 @@ jobs: - "attention-and-distributed-packing" - "accelerated-moe" - "online-data-mixing" + - "mamba-cp" permissions: id-token: write # IMPORTANT: this permission is mandatory for trusted publishing diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 87efabe7..dc505645 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -32,6 +32,7 @@ jobs: - "attention-and-distributed-packing" - "accelerated-moe" - "online-data-mixing" + - "mamba-cp" steps: - name: Delete huge unnecessary tools folder diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py index 092f470d..7fdf6c64 100644 --- a/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py +++ b/plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py @@ -14,25 +14,28 @@ # Standard from typing import Dict -# Third Party try: + # Third Party from mamba_ssm.modules.mamba2_cp import Mamba2CP -except ImportError: - raise ValueError("custom mamba_ssm package installation is needed" - "install from https://github.com/garrett361/mamba/tree/mamba-cp" - ) +except ImportError as exc: + raise ValueError( + "custom mamba_ssm package installation is needed" + "install from https://github.com/garrett361/mamba/tree/mamba-cp" + ) from exc +# Third Party from accelerate.logging import get_logger # pylint: disable=import-error from torch.distributed._tensor.device_mesh import init_device_mesh -from tqdm import tqdm -from transformers.modeling_utils import is_fsdp_enabled -import torch # to avoid rechunking/sharding of the buffers # ideally this is not optimal # this is done to make self attention cp compatible with mamba cp from torch.distributed.tensor.experimental._attention import _cp_options +from tqdm import tqdm +from transformers.modeling_utils import is_fsdp_enabled +import torch + _cp_options.enable_load_balance = False logger = get_logger(__name__) @@ -42,7 +45,8 @@ key_cp = "cp" key_rep = "dp_shard" -# extract ssm config from hf config to be used + +# extract ssm config from hf config to be used # while swapping the mamba modules def get_ssmconfig_from_hfconfig(hf_config) -> Dict: config_ssm = {} @@ -75,6 +79,7 @@ def forward( inference_params=None, ) + # patches each mamba module with mamba cp module # mamba cp module's weights are exactly same as hf mamba module # so we reuse the state dict and the same does not need special handling @@ -94,7 +99,7 @@ def patch_mamba_layers_with_cp_head( if is_fsdp_enabled(): device = torch.device("cpu") rep_size = world_size // cp_degree - + # auto infer ddp and cp ranks # does not work on other combination of parallelisms logger.warning( From 7313944006165c92193acc602eaa030d9f366611 Mon Sep 17 
00:00:00 2001 From: Mehant Kammakomati Date: Fri, 28 Nov 2025 13:24:32 +0530 Subject: [PATCH 22/27] fix: dp cp loss Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/tests/__init__.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 plugins/mamba-cp/tests/__init__.py diff --git a/plugins/mamba-cp/tests/__init__.py b/plugins/mamba-cp/tests/__init__.py new file mode 100644 index 00000000..38a9531e --- /dev/null +++ b/plugins/mamba-cp/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. From e7779ee658bbda2c39af212e34f64f54df9742ed Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 28 Nov 2025 13:36:25 +0530 Subject: [PATCH 23/27] feat: add unit test Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/configs/mcp.yaml | 10 ++----- plugins/mamba-cp/tests/test_mcp_plugin.py | 34 +++++++++++++++++++++++ 2 files changed, 37 insertions(+), 7 deletions(-) create mode 100644 plugins/mamba-cp/tests/test_mcp_plugin.py diff --git a/plugins/mamba-cp/configs/mcp.yaml b/plugins/mamba-cp/configs/mcp.yaml index 5f7f19c4..6194c313 100644 --- a/plugins/mamba-cp/configs/mcp.yaml +++ b/plugins/mamba-cp/configs/mcp.yaml @@ -1,8 +1,4 @@ training: - odm: - odm: - update_interval: 1 # update every step - sampling_interval: 1 # sample category for every sample - reward_type: entropy # type of reward to use - gamma: 0.1 # MAB hyper-parameter - eta: 0.1 # MAB hyper-parameter + mamba: + cp: + degree: 2 # cp degree diff --git a/plugins/mamba-cp/tests/test_mcp_plugin.py b/plugins/mamba-cp/tests/test_mcp_plugin.py new file mode 100644 index 00000000..3528f8cd --- /dev/null +++ b/plugins/mamba-cp/tests/test_mcp_plugin.py @@ -0,0 +1,34 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Standard +import os + +# Third Party +from fms_acceleration.utils import instantiate_framework, read_configuration + +# First Party +from fms_acceleration_mcp import MCPAccelerationPlugin + +# configuration +DIRNAME = os.path.dirname(__file__) +CONFIG_PATH = os.path.join(DIRNAME, "../configs/mcp.yaml") + + +def test_framework_installs_mcp_plugin(): + with instantiate_framework( + read_configuration(CONFIG_PATH), require_packages_check=False + ) as framework: + for plugin in framework.active_plugins: + assert isinstance(plugin[1], MCPAccelerationPlugin) From d5c15431deaa102b2830177bff943400406588bc Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 28 Nov 2025 13:55:12 +0530 Subject: [PATCH 24/27] feat: add unit test Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/pyproject.toml | 2 +- plugins/mamba-cp/tests/test_mcp_plugin.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/plugins/mamba-cp/pyproject.toml b/plugins/mamba-cp/pyproject.toml index 804ef0a7..67948cbd 100644 --- a/plugins/mamba-cp/pyproject.toml +++ b/plugins/mamba-cp/pyproject.toml @@ -20,7 +20,7 @@ classifiers=[ "Programming Language :: Python :: 3.11", ] -dependencies = [] +dependencies = ["pytest"] [tool.hatch.build.targets.wheel] only-include = ["src/fms_acceleration_mcp"] diff --git a/plugins/mamba-cp/tests/test_mcp_plugin.py b/plugins/mamba-cp/tests/test_mcp_plugin.py index 3528f8cd..d2fe4e4f 100644 --- a/plugins/mamba-cp/tests/test_mcp_plugin.py +++ b/plugins/mamba-cp/tests/test_mcp_plugin.py @@ -17,6 +17,7 @@ # Third Party from fms_acceleration.utils import instantiate_framework, read_configuration +import pytest # First Party from fms_acceleration_mcp import MCPAccelerationPlugin @@ -26,6 +27,10 @@ CONFIG_PATH = os.path.join(DIRNAME, "../configs/mcp.yaml") +@pytest.mark.skipif( + not pytest.importorskip("mamba_ssm", reason="mamba_ssm is not installed"), + reason="mamba_ssm is not installed", +) def test_framework_installs_mcp_plugin(): with instantiate_framework( read_configuration(CONFIG_PATH), require_packages_check=False From 2e6f24350a0214e90f412f46f00343cd3179d7a2 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 28 Nov 2025 14:05:56 +0530 Subject: [PATCH 25/27] feat: add unit test Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/tests/test_mcp_plugin.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/plugins/mamba-cp/tests/test_mcp_plugin.py b/plugins/mamba-cp/tests/test_mcp_plugin.py index d2fe4e4f..c62c7d8f 100644 --- a/plugins/mamba-cp/tests/test_mcp_plugin.py +++ b/plugins/mamba-cp/tests/test_mcp_plugin.py @@ -16,12 +16,8 @@ import os # Third Party -from fms_acceleration.utils import instantiate_framework, read_configuration import pytest -# First Party -from fms_acceleration_mcp import MCPAccelerationPlugin - # configuration DIRNAME = os.path.dirname(__file__) CONFIG_PATH = os.path.join(DIRNAME, "../configs/mcp.yaml") @@ -32,6 +28,14 @@ reason="mamba_ssm is not installed", ) def test_framework_installs_mcp_plugin(): + # Third Party + # pylint: disable=import-outside-toplevel + from fms_acceleration.utils import instantiate_framework, read_configuration + + # First Party + # pylint: disable=import-outside-toplevel + from fms_acceleration_mcp import MCPAccelerationPlugin + with instantiate_framework( read_configuration(CONFIG_PATH), require_packages_check=False ) as framework: From 164c9d4226fede1be60c5ae0606bf04c6f792b2d Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 28 Nov 2025 14:31:09 +0530 Subject: 
[PATCH 26/27] feat: add unit test Signed-off-by: Mehant Kammakomati --- plugins/mamba-cp/tox.ini | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/plugins/mamba-cp/tox.ini b/plugins/mamba-cp/tox.ini index 1a21a899..ba3e7209 100644 --- a/plugins/mamba-cp/tox.ini +++ b/plugins/mamba-cp/tox.ini @@ -12,7 +12,10 @@ commands = # install the dependencies here to ensure # the order pip install -e {toxinidir}/../framework - pytest {posargs:tests} + # if all tests skipped + # pytest should not report fail + bash -c 'pytest {posargs:tests}; ec=$?; [ "$ec" = "5" ] && exit 0 || exit $ec' +allowlist_externals = bash [testenv:lint] description = run linters From 2d2012dc0054bc2ad144765c9e9e809922683a2f Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Fri, 28 Nov 2025 16:17:26 +0530 Subject: [PATCH 27/27] nit: remove ds_store files Signed-off-by: Mehant Kammakomati --- .gitignore | 3 ++- plugins/mamba-cp/.DS_Store | Bin 6148 -> 0 bytes .../mamba-cp/src/fms_acceleration_mcp/.DS_Store | Bin 6148 -> 0 bytes 3 files changed, 2 insertions(+), 1 deletion(-) delete mode 100644 plugins/mamba-cp/.DS_Store delete mode 100644 plugins/mamba-cp/src/fms_acceleration_mcp/.DS_Store diff --git a/.gitignore b/.gitignore index e50ac0ea..94bb46d9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__ *.tar.gz *.tox -*.pytest_cache \ No newline at end of file +*.pytest_cache +**/.DS_Store diff --git a/plugins/mamba-cp/.DS_Store b/plugins/mamba-cp/.DS_Store deleted file mode 100644 index 19e14ad0de00558677b2018b401848ef82a27d94..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKO-lnY5Pi`e3WD_LF@K;eUOg?P-o%S|w{B}&+L9saGV24gUgLZI z1$NlueP2J{HC26}-whtQvju%CWO&7f{S~=eax-!_w2xJRf-|Lh&NS|hJ{hBxXt3lu zTa6&rQb!B&=495aW5!wIP3=pv{?6|{<@f78PW^}Q{9FN7z!h)>j-~+bY$?MFLvLLH zSHKncQb5j!$SRm8Yz_73V53g};(%r??8{q1I5A$g#fvy7QHXOZvC-6xoacq6RQ;QYo)`&e&-{A135>obkgRd^y!x;Yz-xf*p*K74*?a#8&}{6 G3VZ-AH9E)u diff --git a/plugins/mamba-cp/src/fms_acceleration_mcp/.DS_Store b/plugins/mamba-cp/src/fms_acceleration_mcp/.DS_Store deleted file mode 100644 index 61382622f25f28f9e4986c5b19e6649b20f38e13..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK!A`?441IwiP2#d6$NT`Q_=9R1H^gP)Jjyl(V%?^-l{og(dA7r}Lb<|(>?wJ+ zV>`){mNWpuxLMo)BLG7-!5~VHh`V>_ECmY&tx@gjr@N-A_dR7!(0aMYp`P7jaCrvTCFAX78+Yn zE9rYz|B6`^tIQAeXFt}TVapmNBfqZyKJkCFtJ9?3SLF;i1I~am@J9^r&K4P78hY;x zI0MeWfdM%mBAZ~Iur<`9gPlGBhy%J+sLNYIa$>?fVQa_{O0ZO-rH1sx2$oKNOmTU_ z*3i-sbY|8`Gk-o_g3eBV)Nq8{(0gaV8R#-_Zstht|6BfLdLQ}S6t|oKXW*YP5C)U! zWXw&ayY<`l)=;X5U(<>HA)thK=M4M; F10Q>8JNp0t
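Taken together, the `utils.py` revisions in this series converge on a single layer-swapping flow: build a `dp_shard × cp` device mesh, disable the load-balanced attention chunking, and replace each Hugging Face `layer.mamba` module with a CP-aware `Mamba2CP` that reuses the original state dict. The condensed sketch below restates that flow for reference; it assumes the custom mamba_ssm fork is installed and a distributed process group is already initialized, and the helper names `swap_mamba_layers` and `build_ssm_kwargs`, as well as the trimmed SSM kwargs, are illustrative rather than part of the plugin.

```python
# Condensed sketch of the final layer-swapping flow in utils.py; assumes the
# custom mamba_ssm fork (garrett361/mamba, mamba-cp branch) and an initialized
# torch.distributed process group. Helper names here are illustrative.
import torch
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental._attention import _cp_options
from mamba_ssm.modules.mamba2_cp import Mamba2CP

# Keep SDPA context-parallel chunking aligned with the mamba CP kernels
# (see the "Known Issues" note on load balancing in the README).
_cp_options.enable_load_balance = False


def build_ssm_kwargs(hf_config) -> dict:
    # Mirrors get_ssmconfig_from_hfconfig; only the keys visible in the diffs
    # are reproduced here, the remaining SSM kwargs are elided.
    return {
        "d_model": hf_config.hidden_size,
        "d_state": 128,
        "ngroups": hf_config.mamba_n_groups,
        "rmsnorm": True,
    }


def swap_mamba_layers(model, cp_degree, world_size,
                      cp_mamba_impl="allgather", cp_mamba_recompute=False):
    if cp_degree == 1:
        raise ValueError("CP degree can't be one")
    rep_size = world_size // cp_degree
    if rep_size == 1:
        mesh = init_device_mesh("cuda", (cp_degree,), mesh_dim_names=("cp",))
    else:
        mesh = init_device_mesh("cuda", (rep_size, cp_degree),
                                mesh_dim_names=("dp_shard", "cp"))

    ssm_kwargs = build_ssm_kwargs(model.config)
    cp_kwargs = {
        "cp_mesh": mesh["cp"],
        "cp_mamba_impl": cp_mamba_impl,
        "cp_mamba_recompute": cp_mamba_recompute,
    }

    with torch.no_grad():
        dtype, device = model.dtype, model.device
        for layer in model.model.layers:
            if getattr(layer, "mamba", None) is not None:
                # Mamba2CP shares the parameter layout of the HF mamba module,
                # so the original state dict can be loaded directly.
                new_mamba = Mamba2CP(**ssm_kwargs, **cp_kwargs)
                new_mamba.load_state_dict(layer.mamba.state_dict())
                layer.mamba = new_mamba
                layer.to(dtype).to(device)

    if hasattr(model, "tie_weights"):
        model.tie_weights()
    return model
```

In the plugin itself the swapped-in module is the `Mamba2CPHF` adapter rather than raw `Mamba2CP`: its `forward` maps the Hugging Face call signature (`hidden_states`, `cache_params`, `cache_position`, `attention_mask`, `seq_idx`) onto `Mamba2CP.forward(u=hidden_states, ...)`, so the surrounding model code is left unchanged.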