diff --git a/.env.miner.template b/.env.miner.template
new file mode 100644
index 00000000..98cf6521
--- /dev/null
+++ b/.env.miner.template
@@ -0,0 +1,21 @@
+# ======= Miner Configuration (FILL IN) =======
+# Wallet
+WALLET_NAME=
+WALLET_HOTKEY=
+
+# Network
+CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443
+# OTF public finney endpoint: wss://entrypoint-finney.opentensor.ai:443
+# OTF public testnet endpoint: wss://test.finney.opentensor.ai:443/
+
+# Axon port and (optionally) ip
+AXON_PORT=8091
+AXON_EXTERNAL_IP=[::]
+
+FORCE_VPERMIT=true
+
+# Device for detection models
+DEVICE=cpu
+
+# Logging
+LOGLEVEL=trace
\ No newline at end of file
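
The miner template above is consumed as plain environment variables. As a minimal sketch (assuming python-dotenv and that the template has been copied to `.env.miner` and filled in; none of this is part of the diff itself), a miner process could read it like this:

```python
# Hypothetical sketch, not part of this diff: reading the miner template with
# python-dotenv, assuming it has been copied to `.env.miner` and filled in.
import os

from dotenv import load_dotenv

load_dotenv(".env.miner")

wallet_name = os.environ["WALLET_NAME"]
wallet_hotkey = os.environ["WALLET_HOTKEY"]
chain_endpoint = os.getenv("CHAIN_ENDPOINT", "wss://entrypoint-finney.opentensor.ai:443")
axon_port = int(os.getenv("AXON_PORT", "8091"))
force_vpermit = os.getenv("FORCE_VPERMIT", "true").lower() == "true"
device = os.getenv("DEVICE", "cpu")

print(f"Miner axon on port {axon_port}, detection models on {device}")
```

Presumably the neuron start scripts map these onto the miner's CLI arguments; only the variable names above are taken from the template.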
diff --git a/.env.validator.template b/.env.validator.template
new file mode 100644
index 00000000..bf08e2b7
--- /dev/null
+++ b/.env.validator.template
@@ -0,0 +1,29 @@
+# ======= Validator Configuration (FILL IN) =======
+# Wallet
+WALLET_NAME=
+WALLET_HOTKEY=
+
+# API Keys
+WANDB_API_KEY=
+HUGGING_FACE_TOKEN=
+
+# Network
+CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443
+# OTF public finney endpoint: wss://entrypoint-finney.opentensor.ai:443
+# OTF public testnet endpoint: wss://test.finney.opentensor.ai:443/
+
+# Validator Proxy
+PROXY_PORT=10913
+PROXY_EXTERNAL_PORT=10913
+
+# Cache config
+SN34_CACHE_DIR=~/.cache/sn34
+HEARTBEAT=true
+
+# Generator config
+GENERATION_BATCH_SIZE=3
+DEVICE=cuda
+
+# Other
+LOGLEVEL=info
+AUTO_UPDATE=false
\ No newline at end of file
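
A couple of the validator settings need light parsing before use: `SN34_CACHE_DIR` uses `~`, and flags like `HEARTBEAT` and `AUTO_UPDATE` arrive as the strings "true"/"false". A minimal sketch, again assuming python-dotenv and a filled-in copy named `.env.validator`:

```python
# Hypothetical sketch, not part of this diff: typed parsing of a few validator
# settings, assuming the template has been copied to `.env.validator`.
import os
from pathlib import Path

from dotenv import load_dotenv

load_dotenv(".env.validator")

# "~/.cache/sn34" must be expanded to an absolute path before use.
cache_dir = Path(os.getenv("SN34_CACHE_DIR", "~/.cache/sn34")).expanduser()

# Boolean flags arrive as the strings "true"/"false".
heartbeat = os.getenv("HEARTBEAT", "true").lower() == "true"
auto_update = os.getenv("AUTO_UPDATE", "false").lower() == "true"

batch_size = int(os.getenv("GENERATION_BATCH_SIZE", "3"))
proxy_port = int(os.getenv("PROXY_PORT", "10913"))

print(cache_dir, heartbeat, auto_update, batch_size, proxy_port)
```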
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
deleted file mode 100644
index c8f67697..00000000
--- a/.github/workflows/ci.yml
+++ /dev/null
@@ -1,40 +0,0 @@
-# This workflow will install Python dependencies, run tests and lint with a single version of Python
-# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
-
-name: Continuous Integration
-
-on:
- push:
- branches: [ "main", "testnet" ]
- pull_request:
- branches: [ "main", "testnet" ]
-
-permissions:
- contents: read
-
-jobs:
- test:
-
- runs-on: ubuntu-latest
-
- steps:
- - uses: actions/checkout@v4
- - name: Set up Python 3.10
- uses: actions/setup-python@v3
- with:
- python-version: "3.10"
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install flake8 pytest pytest-asyncio
- pip install -r requirements.txt
- - name: Lint with flake8
- run: |
- # stop the build if there are Python syntax errors or undefined names
- flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
- # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
- flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- #- name: Test with pytest
- # run: |
- # # run tests in tests/ dir and only fail if there are failures or errors
- # pytest tests/ --verbose --failed-first --exitfirst --disable-warnings
diff --git a/.gitmodules b/.gitmodules
deleted file mode 100644
index e69de29b..00000000
diff --git a/LICENSE b/LICENSE
index 75623755..78afcb52 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,21 +1,17 @@
-MIT License
+The MIT License (MIT)
+Copyright © 2023 Yuma Rao
+Copyright © 2025 BitMind
-Copyright (c) 2023 Opentensor
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+documentation files (the "Software"), to deal in the Software without restriction, including without limitation
+the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
+and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of
+the Software.
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
+THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
diff --git a/README.md b/README.md
index 242c433e..5c795d06 100644
--- a/README.md
+++ b/README.md
@@ -1,105 +1,87 @@
-
+
-BitMind Subnet
Bittensor Subnet 34 | Deepfake Detection
-
-
-
-The BitMind Subnet is **the world's first decentralized AI-generated content detection network**. Built on Bittensor, our incentive mechanism rewards the most accurate detection algorithms, creating an adaptive defense against synthetic media.
-
-
+SN34
Deepfake Detection
+
+
+
+
+
+
+
## Decentralized Detection of AI Generated Content
-The explosive growth of generative AI technology has unleashed an unprecedented wave of synthetic media creation. These AI-generated images, videos, and other content have become remarkably sophisticated, making them virtually indistinguishable from authentic media. This development presents a critical challenge to information integrity and societal trust in the digital age, as the line between real and synthetic content continues to blur.
+The explosive growth of generative AI technology has unleashed an unprecedented wave of synthetic media creation. AI-generated audiovisual content has become remarkably sophisticated, often indistinguishable from authentic media. This development presents a critical challenge to information integrity and societal trust in the digital age, as the line between real and synthetic content continues to blur.
To address this growing challenge, SN34 aims to create the most accurate fully-generalized detection system. Here, fully-generalized means that the system is capable of detecting both synthetic and semi-synthetic media with a high degree of accuracy, regardless of their content or which model generated them. Our incentive mechanism evolves alongside state-of-the-art generative AI, rewarding miners whose detection algorithms best adapt to new forms of synthetic content.
## Core Components
-> This documentation assumes basic familiarity with Bittensor concepts. For an introduction, please check out the docs: https://docs.bittensor.com/learn/bittensor-building-blocks.
+> This documentation assumes basic familiarity with [Bittensor concepts](https://docs.bittensor.com/learn/bittensor-building-blocks).
+
+Miners
-**Miners**
- Miners are tasked with running binary classifiers that discern between genuine and AI-generated content, and are rewarded based on their accuracy.
-- Miners predict a float value in [0., 1.], with values greater than 0.5 indicating the image or video is AI generated.
+- For each challenge, a miner is presented with an image or video and must respond with a multiclass prediction [$p_{real}$, $p_{synthetic}$, $p_{semisynthetic}$] indicating whether the media is real, fully generated, or partially modified by AI.
-**Validators**
+Validators
- Validators challenge miners with a balanced mix of real and synthetic media drawn from a diverse pool of sources.
-- We continually add new datasets and generative models to our validators in order to maximize coverage of the types of diverse data. Models and datasets are defined in `bitmind/validator/config.py`.
+- We continually add new datasets and generative models to our validators in order to evolve the subnet's detection capabilities alongside advances in generative AI.
## Subnet Architecture
-> Overview of the validator neuron, miner neuron, and other components external to the subnet.
-
-
-
-**Challenge Generation and Scoring (Pink Arrows)**
-
-For each challenge, the validator randomly samples a real or synthetic image/video from the cache, applies random augmentations to the sampled media, and distributes the augmented data to 50 randomly selected miners for classification. It then scores the miners responses, and logs comprehensive challenge results to [Weights and Biases](https://wandb.ai/bitmindai/bitmind-subnet), including the generated media, original prompt, miner responses and rewards, and other challenge metadata.
-
-**Data Generation and Downloads (Blue Arrows)**:
-
-The synthetic data generator coordinates a VLM and LLM to generate prompts for our suite of text-to-image, image-to-image, and text-to-video models. Each generated image/video is written to the cache along with the prompt, generation parameters, and other metadata.
-
-The real data fetcher performs partial dataset downloads, fetching random compressed chunks of datasets from HuggingFace and unpacking random portions of these chunks into the cache along with their metadata. Partial downloads avoid requiring TBs of space for large video datasets like OpenVid1M.
-
-**Organic Traffic (Green Arrows)**
+Overview of the validator neuron, miner neuron, and other components external to the subnet.
+
+
+
+
+Challenge Generation and Scoring (Peach Arrows)
+
+ - The validator first randomly samples an image or video from its local media cache.
+ - The sampled media can be real, synthetic, or semisynthetic, and was either downloaded from a dataset on HuggingFace or generated locally by one of many generative models.
+ - The sampled media is then augmented by a pipeline of random transformations, increasing challenge difficulty and mitigating gaming of the incentive mechanism via lookups.
+ - The augmented media is then sent to miners for classification.
+ - The validator scores the miners' responses and logs comprehensive challenge results to Weights and Biases, including the generated media, original prompt, miner responses and rewards, and other challenge metadata.
+
+
+
+
+Data Generation and Downloads (Blue Arrows)
+The blue arrows show how the validator media cache is maintained by two parallel tracks:
+
+- The synthetic data generator coordinates a VLM and LLM to generate prompts for our suite of text-to-image, image-to-image, and text-to-video models. Each generated image/video is written to the cache along with the prompt, generation parameters, and other metadata.
+- The real data fetcher performs partial dataset downloads, fetching random compressed chunks of datasets from HuggingFace and unpacking random portions of these chunks into the cache along with their metadata. Partial downloads avoid requiring TBs of space for large video datasets like OpenVid1M.
+
+
+
+
+Organic Traffic (Green Arrows)
Application requests are distributed to validators by an API server and load balancer in BitMind's cloud. A vector database caches subnet responses to avoid unnecessary repeat calls for widely circulated images on the internet.
+
+
## Community
-
+
-
-For real-time discussions, community support, and regular updates, join our Discord server. Connect with developers, researchers, and users to get the most out of BitMind Subnet.
-
-## License
-This repository is licensed under the MIT License.
-```text
-# The MIT License (MIT)
-# Copyright © 2023 Yuma Rao
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
-
-# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-```
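
The miner response format described in the README above is a three-way probability vector. As an illustrative sketch only (assuming a PyTorch classifier that emits three logits; the actual miner models are not specified here), the mapping from logits to [$p_{real}$, $p_{synthetic}$, $p_{semisynthetic}$] might look like:

```python
# Hypothetical sketch, not part of this diff: converting three classifier
# logits into the [p_real, p_synthetic, p_semisynthetic] vector that miners
# return. The classifier itself is assumed; only the output format comes from
# the README above.
import torch

logits = torch.tensor([1.2, -0.3, 0.4])      # real, synthetic, semisynthetic
probs = torch.softmax(logits, dim=0)         # non-negative, sums to 1.0

p_real, p_synthetic, p_semisynthetic = probs.tolist()
predicted_class = int(torch.argmax(probs))   # 0 = real, 1 = synthetic, 2 = semisynthetic

print([round(p, 3) for p in probs.tolist()], predicted_class)
```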
diff --git a/VERSION b/VERSION
new file mode 100644
index 00000000..56fea8a0
--- /dev/null
+++ b/VERSION
@@ -0,0 +1 @@
+3.0.0
\ No newline at end of file
diff --git a/autoupdate_miner_steps.sh b/autoupdate_miner_steps.sh
deleted file mode 100755
index ea6a7af5..00000000
--- a/autoupdate_miner_steps.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-#!/bin/bash
-
-# Thank you to Namoray of SN19 for their autoupdate implementation!
-# THIS FILE CONTAINS THE STEPS NEEDED TO AUTOMATICALLY UPDATE THE REPO
-# THIS FILE ITSELF MAY CHANGE FROM UPDATE TO UPDATE, SO WE CAN DYNAMICALLY FIX ANY ISSUES
-
-echo $CONDA_PREFIX
-./setup_env.sh
-echo "Autoupdate steps complete :)"
diff --git a/autoupdate_validator_steps.sh b/autoupdate_validator_steps.sh
deleted file mode 100755
index ea6a7af5..00000000
--- a/autoupdate_validator_steps.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-#!/bin/bash
-
-# Thank you to Namoray of SN19 for their autoupdate implementation!
-# THIS FILE CONTAINS THE STEPS NEEDED TO AUTOMATICALLY UPDATE THE REPO
-# THIS FILE ITSELF MAY CHANGE FROM UPDATE TO UPDATE, SO WE CAN DYNAMICALLY FIX ANY ISSUES
-
-echo $CONDA_PREFIX
-./setup_env.sh
-echo "Autoupdate steps complete :)"
diff --git a/base_miner/DFB/README.md b/base_miner/DFB/README.md
deleted file mode 100644
index 7aac3689..00000000
--- a/base_miner/DFB/README.md
+++ /dev/null
@@ -1,14 +0,0 @@
-## UCF
-
-This model has been adapted from [DeepfakeBench](https://github.com/SCLBD/DeepfakeBench).
-
-##
-
-- **Train UCF model**:
- - Use `train_ucf.py`, which will download necessary pretrained `xception` backbone weights from HuggingFace (if not present locally) and start a training job with logging outputs in `.logs/`.
- - Customize the training job by editing `config/ucf.yaml`
- - `pm2 start train_ucf.py --no-autorestart` to train a generalist detector on datasets from `DATASET_META`
- - `pm2 start train_ucf.py --no-autorestart -- --faces_only` to train a face expert detector on preprocessed-face only datasets
-
-- **Miner Neurons**:
- - The `UCF` class in `pretrained_ucf.py` is used by miner neurons to load and perform inference with pretrained UCF model weights.
\ No newline at end of file
diff --git a/base_miner/DFB/config/__init__.py b/base_miner/DFB/config/__init__.py
deleted file mode 100644
index 1ed99d5d..00000000
--- a/base_miner/DFB/config/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-import os
-import sys
-current_file_path = os.path.abspath(__file__)
-parent_dir = os.path.dirname(os.path.dirname(current_file_path))
-project_root_dir = os.path.dirname(parent_dir)
-sys.path.append(parent_dir)
-sys.path.append(project_root_dir)
\ No newline at end of file
diff --git a/base_miner/DFB/config/constants.py b/base_miner/DFB/config/constants.py
deleted file mode 100644
index 2ae2373e..00000000
--- a/base_miner/DFB/config/constants.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import os
-
-CONFIGS_DIR = os.path.dirname(os.path.abspath(__file__))
-BASE_PATH = os.path.abspath(os.path.join(CONFIGS_DIR, "..")) # Points to bitmind-subnet/base_miner/DFB/
-WEIGHTS_DIR = os.path.join(BASE_PATH, "weights")
-
-CONFIG_PATHS = {
- 'UCF': os.path.join(CONFIGS_DIR, "ucf.yaml"),
- 'TALL': os.path.join(CONFIGS_DIR, "tall.yaml")
-}
-
-HF_REPOS = {
- "UCF": "bitmind/ucf",
- "TALL": "bitmind/tall"
-}
-
-BACKBONE_CKPT = "xception_best.pth"
-
-DLIB_FACE_PREDICTOR_PATH = os.path.abspath(os.path.join(BASE_PATH, "../../bitmind/dataset_processing/dlib_tools/shape_predictor_81_face_landmarks.dat"))
\ No newline at end of file
diff --git a/base_miner/DFB/config/helpers.py b/base_miner/DFB/config/helpers.py
deleted file mode 100644
index 557bf896..00000000
--- a/base_miner/DFB/config/helpers.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import yaml
-
-
-def save_config(config, outputs_dir):
- """
- Saves a config dictionary as both a pickle file and a YAML file, ensuring only basic types are saved.
- Also, lists like 'mean' and 'std' are saved in flow style (on a single line).
-
- Args:
- config (dict): The configuration dictionary to save.
- outputs_dir (str): The directory path where the files will be saved.
- """
-
- def is_basic_type(value):
- """
- Check if a value is a basic data type that can be saved in YAML.
- Basic types include int, float, str, bool, list, and dict.
- """
- return isinstance(value, (int, float, str, bool, list, dict, type(None)))
-
- def filter_dict(data_dict):
- """
- Recursively filter out any keys from the dictionary whose values contain non-basic types (e.g., objects).
- """
- if not isinstance(data_dict, dict):
- return data_dict
-
- filtered_dict = {}
- for key, value in data_dict.items():
- if isinstance(value, dict):
- # Recursively filter nested dictionaries
- nested_dict = filter_dict(value)
- if nested_dict: # Only add non-empty dictionaries
- filtered_dict[key] = nested_dict
- elif is_basic_type(value):
- # Add if the value is a basic type
- filtered_dict[key] = value
- else:
- # Skip the key if the value is not a basic type (e.g., an object)
- print(f"Skipping key '{key}' because its value is of type {type(value)}")
-
- return filtered_dict
-
- def save_dict_to_yaml(data_dict, file_path):
- """
- Saves a dictionary to a YAML file, excluding any keys where the value is an object or contains an object.
- Additionally, ensures that specific lists (like 'mean' and 'std') are saved in flow style.
-
- Args:
- data_dict (dict): The dictionary to save.
- file_path (str): The local file path where the YAML file will be saved.
- """
-
- # Custom representer for lists to force flow style (compact lists)
- class FlowStyleList(list):
- pass
-
- def flow_style_list_representer(dumper, data):
- return dumper.represent_sequence('tag:yaml.org,2002:seq', data, flow_style=True)
-
- yaml.add_representer(FlowStyleList, flow_style_list_representer)
-
- # Preprocess specific lists to be in flow style
- if 'mean' in data_dict:
- data_dict['mean'] = FlowStyleList(data_dict['mean'])
- if 'std' in data_dict:
- data_dict['std'] = FlowStyleList(data_dict['std'])
-
- try:
- # Filter the dictionary
- filtered_dict = filter_dict(data_dict)
-
- # Save the filtered dictionary as YAML
- with open(file_path, 'w') as f:
- yaml.dump(filtered_dict, f, default_flow_style=False) # Save with default block style except for FlowStyleList
- print(f"Filtered dictionary successfully saved to {file_path}")
- except Exception as e:
- print(f"Error saving dictionary to YAML: {e}")
-
- # Save as YAML
- save_dict_to_yaml(config, outputs_dir + '/config.yaml')
\ No newline at end of file
diff --git a/base_miner/DFB/config/tall.yaml b/base_miner/DFB/config/tall.yaml
deleted file mode 100644
index 96de6a86..00000000
--- a/base_miner/DFB/config/tall.yaml
+++ /dev/null
@@ -1,89 +0,0 @@
-# model setting
-pretrained: https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth # path to a pre-trained model, if using one
-model_name: tall # model name
-
-mask_grid_size: 16
-num_classes: 2
-embed_dim: 128
-mlp_ratio: 4.0
-patch_size: 4
-window_size: [14, 14, 14, 7]
-depths: [2, 2, 18, 2]
-num_heads: [4, 8, 16, 32]
-ape: true # use absolution position embedding
-thumbnail_rows: 2
-drop_rate: 0
-drop_path_rate: 0.1
-
-# dataset
-all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV]
-train_dataset: [FaceForensics++]
-test_dataset: [Celeb-DF-v2]
-
-compression: c23 # compression-level for videos
-train_batchSize: 64 # training batch size
-test_batchSize: 64 # test batch size
-workers: 4 # number of data loading workers
-frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing
-resolution: 224 # resolution of output image to network
-with_mask: false # whether to include mask information in the input
-with_landmark: false # whether to include facial landmark information in the input
-video_mode: True # whether to use video-level data
-clip_size: 4 # number of frames in each clip, should be square number of an integer
-dataset_type: tall
-
-# data augmentation
-use_data_augmentation: false # Add this flag to enable/disable data augmentation
-data_aug:
- flip_prob: 0.5
- rotate_prob: 0.5
- rotate_limit: [-10, 10]
- blur_prob: 0.5
- blur_limit: [3, 7]
- brightness_prob: 0.5
- brightness_limit: [-0.1, 0.1]
- contrast_limit: [-0.1, 0.1]
- quality_lower: 40
- quality_upper: 100
-
-# mean and std for normalization
-mean: [0.485, 0.456, 0.406]
-std: [0.229, 0.224, 0.225]
-
-# optimizer config
-optimizer:
- # choose between 'adam' and 'sgd'
- type: adam
- adam:
- lr: 0.00002 # learning rate
- beta1: 0.9 # beta1 for Adam optimizer
- beta2: 0.999 # beta2 for Adam optimizer
- eps: 0.00000001 # epsilon for Adam optimizer
- weight_decay: 0.0005 # weight decay for regularization
- amsgrad: false
- sgd:
- lr: 0.0002 # learning rate
- momentum: 0.9 # momentum for SGD optimizer
- weight_decay: 0.0005 # weight decay for regularization
-
-# training config
-lr_scheduler: null # learning rate scheduler
-nEpochs: 100 # number of epochs to train for
-start_epoch: 0 # manual epoch number (useful for restarts)
-save_epoch: 1 # interval epochs for saving models
-rec_iter: 100 # interval iterations for recording
-logdir: ./logs # folder to output images and logs
-manualSeed: 1024 # manual seed for random number generation
-save_ckpt: true # whether to save checkpoint
-save_feat: true # whether to save features
-
-# loss function
-loss_func: cross_entropy # loss function to use
-losstype: null
-
-# metric
-metric_scoring: auc # metric for evaluation (auc, acc, eer, ap)
-
-# cuda
-cuda: true # whether to use CUDA acceleration
-cudnn: true # whether to use CuDNN for convolution operations
\ No newline at end of file
diff --git a/base_miner/DFB/config/ucf.yaml b/base_miner/DFB/config/ucf.yaml
deleted file mode 100644
index cee1097f..00000000
--- a/base_miner/DFB/config/ucf.yaml
+++ /dev/null
@@ -1,75 +0,0 @@
-# log dir
-log_dir: ../debug_logs/ucf
-
-# model setting
-pretrained:
- hf_repo: bm_ucf
- filename: xception-best.pth
-model_name: ucf # model name
-backbone_name: xception # backbone name
-encoder_feat_dim: 512 # feature dimension of the backbone
-
-#backbone setting
-backbone_config:
- mode: adjust_channel
- num_classes: 2
- inc: 3
- dropout: false
-
-compression: c23 # compression-level for videos
-train_batchSize: 32 # training batch size
-test_batchSize: 32 # test batch size
-workers: 8 # number of data loading workers
-frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing
-resolution: 256 # resolution of output image to network
-with_mask: false # whether to include mask information in the input
-with_landmark: false # whether to include facial landmark information in the input
-save_ckpt: true # whether to save checkpoint
-save_feat: true # whether to save features
-specific_task_number: 5 # default num datasets in FF++ used by DFB, overwritten in training
-
-# mean and std for normalization
-mean: [0.5, 0.5, 0.5]
-std: [0.5, 0.5, 0.5]
-
-# optimizer config
-optimizer:
- # choose between 'adam' and 'sgd'
- type: adam
- adam:
- lr: 0.0002 # learning rate
- beta1: 0.9 # beta1 for Adam optimizer
- beta2: 0.999 # beta2 for Adam optimizer
- eps: 0.00000001 # epsilon for Adam optimizer
- weight_decay: 0.0005 # weight decay for regularization
- amsgrad: false
- sgd:
- lr: 0.0002 # learning rate
- momentum: 0.9 # momentum for SGD optimizer
- weight_decay: 0.0005 # weight decay for regularization
-
-# training config
-lr_scheduler: null # learning rate scheduler
-nEpochs: 5 # number of epochs to train for
-start_epoch: 0 # manual epoch number (useful for restarts)
-save_epoch: 1 # interval epochs for saving models
-rec_iter: 100 # interval iterations for recording
-logdir: ./logs # folder to output images and logs
-manualSeed: 1024 # manual seed for random number generation
-save_ckpt: false # whether to save checkpoint
-
-# loss function
-loss_func:
- cls_loss: cross_entropy # loss function to use
- spe_loss: cross_entropy
- con_loss: contrastive_regularization
- rec_loss: l1loss
-losstype: null
-
-# metric
-metric_scoring: auc # metric for evaluation (auc, acc, eer, ap)
-
-# cuda
-
-cuda: true # whether to use CUDA acceleration
-cudnn: true # whether to use CuDNN for convolution operations
diff --git a/base_miner/DFB/config/xception.yaml b/base_miner/DFB/config/xception.yaml
deleted file mode 100644
index 9198f69c..00000000
--- a/base_miner/DFB/config/xception.yaml
+++ /dev/null
@@ -1,86 +0,0 @@
-# log dir
-log_dir: /data/home/zhiyuanyan/DeepfakeBench/logs/testing_bench
-
-# model setting
-pretrained: /data/home/zhiyuanyan/DeepfakeBench/training/pretrained/xception-b5690688.pth # path to a pre-trained model, if using one
-model_name: xception # model name
-backbone_name: xception # backbone name
-
-#backbone setting
-backbone_config:
- mode: original
- num_classes: 2
- inc: 3
- dropout: false
-
-# dataset
-all_dataset: [FaceForensics++, FF-F2F, FF-DF, FF-FS, FF-NT, FaceShifter, DeepFakeDetection, Celeb-DF-v1, Celeb-DF-v2, DFDCP, DFDC, DeeperForensics-1.0, UADFV]
-train_dataset: [FaceForensics++]
-test_dataset: [FaceForensics++, DeepFakeDetection]
-
-compression: c23 # compression-level for videos
-train_batchSize: 32 # training batch size
-test_batchSize: 32 # test batch size
-workers: 8 # number of data loading workers
-frame_num: {'train': 32, 'test': 32} # number of frames to use per video in training and testing
-resolution: 256 # resolution of output image to network
-with_mask: false # whether to include mask information in the input
-with_landmark: false # whether to include facial landmark information in the input
-
-
-# data augmentation
-use_data_augmentation: true # Add this flag to enable/disable data augmentation
-data_aug:
- flip_prob: 0.5
- rotate_prob: 0.0
- rotate_limit: [-10, 10]
- blur_prob: 0.5
- blur_limit: [3, 7]
- brightness_prob: 0.5
- brightness_limit: [-0.1, 0.1]
- contrast_limit: [-0.1, 0.1]
- quality_lower: 40
- quality_upper: 100
-
-# mean and std for normalization
-mean: [0.5, 0.5, 0.5]
-std: [0.5, 0.5, 0.5]
-
-# optimizer config
-optimizer:
- # choose between 'adam' and 'sgd'
- type: adam
- adam:
- lr: 0.0002 # learning rate
- beta1: 0.9 # beta1 for Adam optimizer
- beta2: 0.999 # beta2 for Adam optimizer
- eps: 0.00000001 # epsilon for Adam optimizer
- weight_decay: 0.0005 # weight decay for regularization
- amsgrad: false
- sgd:
- lr: 0.0002 # learning rate
- momentum: 0.9 # momentum for SGD optimizer
- weight_decay: 0.0005 # weight decay for regularization
-
-# training config
-lr_scheduler: null # learning rate scheduler
-nEpochs: 10 # number of epochs to train for
-start_epoch: 0 # manual epoch number (useful for restarts)
-save_epoch: 1 # interval epochs for saving models
-rec_iter: 100 # interval iterations for recording
-logdir: ./logs # folder to output images and logs
-manualSeed: 1024 # manual seed for random number generation
-save_ckpt: true # whether to save checkpoint
-save_feat: true # whether to save features
-
-# loss function
-loss_func: cross_entropy # loss function to use
-losstype: null
-
-# metric
-metric_scoring: auc # metric for evaluation (auc, acc, eer, ap)
-
-# cuda
-
-cuda: true # whether to use CUDA acceleration
-cudnn: true # whether to use CuDNN for convolution operations
diff --git a/base_miner/DFB/detectors/__init__.py b/base_miner/DFB/detectors/__init__.py
deleted file mode 100644
index cbaeaf92..00000000
--- a/base_miner/DFB/detectors/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import os
-import sys
-current_file_path = os.path.abspath(__file__)
-parent_dir = os.path.dirname(os.path.dirname(current_file_path))
-project_root_dir = os.path.dirname(parent_dir)
-sys.path.append(parent_dir)
-sys.path.append(project_root_dir)
-
-from metrics.registry import DETECTOR
-
-from .ucf_detector import UCFDetector
-from .tall_detector import TALLDetector
\ No newline at end of file
diff --git a/base_miner/DFB/detectors/base_detector.py b/base_miner/DFB/detectors/base_detector.py
deleted file mode 100644
index c2143ccc..00000000
--- a/base_miner/DFB/detectors/base_detector.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# author: Zhiyuan Yan
-# email: zhiyuanyan@link.cuhk.edu.cn
-# date: 2023-0706
-# description: Abstract Class for the Deepfake Detector
-
-import abc
-import torch
-import torch.nn as nn
-from typing import Union
-
-class AbstractDetector(nn.Module, metaclass=abc.ABCMeta):
- """
- All deepfake detectors should subclass this class.
- """
- def __init__(self, config=None, load_param: Union[bool, str] = False):
- """
- config: (dict)
- configurations for the model
- load_param: (False | True | Path(str))
- False Do not read; True Read the default path; Path Read the required path
- """
- super().__init__()
-
- @abc.abstractmethod
- def features(self, data_dict: dict) -> torch.tensor:
- """
- Returns the features from the backbone given the input data.
- """
- pass
-
- @abc.abstractmethod
- def forward(self, data_dict: dict, inference=False) -> dict:
- """
- Forward pass through the model, returning the prediction dictionary.
- """
- pass
-
- @abc.abstractmethod
- def classifier(self, features: torch.tensor) -> torch.tensor:
- """
- Classifies the features into classes.
- """
- pass
-
- @abc.abstractmethod
- def build_backbone(self, config):
- """
- Builds the backbone of the model.
- """
- pass
-
- @abc.abstractmethod
- def build_loss(self, config):
- """
- Builds the loss function for the model.
- """
- pass
-
- @abc.abstractmethod
- def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
- """
- Returns the losses for the model.
- """
- pass
-
- @abc.abstractmethod
- def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
- """
- Returns the training metrics for the model.
- """
- pass
\ No newline at end of file
diff --git a/base_miner/DFB/detectors/tall_detector.py b/base_miner/DFB/detectors/tall_detector.py
deleted file mode 100644
index 8a175fa3..00000000
--- a/base_miner/DFB/detectors/tall_detector.py
+++ /dev/null
@@ -1,1019 +0,0 @@
-"""
-# author: Kangran Zhao
-# email: kangranzhao@link.cuhk.edu.cn
-# date: 2023-0822
-# description: Class for the TALLDetector
-
-Functions in the Class are summarized as:
-1. __init__: Initialization
-2. build_backbone: Backbone-building
-3. build_loss: Loss-function-building
-4. features: Feature-extraction
-5. classifier: Classification
-6. get_losses: Loss-computation
-7. get_train_metrics: Training-metrics-computation
-8. get_test_metrics: Testing-metrics-computation
-9. forward: Forward-propagation
-
-Reference:
-@inproceedings{xu2023tall,
- title={TALL: Thumbnail Layout for Deepfake Video Detection},
- author={Xu, Yuting and Liang, Jian and Jia, Gengyun and Yang, Ziming and Zhang, Yanhao and He, Ran},
- booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
- pages={22658--22668},
- year={2023}
-}
-"""
-
-import logging
-import math
-import torch
-import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
-from einops import rearrange
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-from torch.hub import load_state_dict_from_url
-
-from .base_detector import AbstractDetector
-from base_miner.DFB.detectors import DETECTOR
-from base_miner.DFB.loss import LOSSFUNC
-from base_miner.DFB.metrics.base_metrics_class import calculate_metrics_for_train
-
-_logger = logging.getLogger(__name__)
-
-
-@DETECTOR.register_module(module_name='tall')
-class TALLDetector(AbstractDetector):
- def __init__(self, config, device='cuda'):
- super().__init__()
- self.device = device
- self.model = self.build_backbone(config).to(self.device)
- self.loss_func = self.build_loss(config)
-
- def build_backbone(self, config):
- model_kwargs = dict(
- num_classes=config['num_classes'],
- embed_dim=config['embed_dim'],
- mlp_ratio=config['mlp_ratio'],
- patch_size=config['patch_size'],
- window_size=config['window_size'],
- depths=config['depths'],
- num_heads=config['num_heads'],
- ape=config['ape'],
- thumbnail_rows=config['thumbnail_rows'],
- drop_rate=config['drop_rate'],
- drop_path_rate=config['drop_path_rate'],
- use_checkpoint=False,
- bottleneck=False,
- duration=config['clip_size'],
- device=self.device
- )
-
- default_cfg = {
- 'url': config['pretrained'],
- 'num_classes': 1000,
- 'input_size': (3, 224, 224),
- 'pool_size': None,
- 'crop_pct': .9,
- 'interpolation': 'bicubic',
- 'mean': IMAGENET_DEFAULT_MEAN,
- 'std': IMAGENET_DEFAULT_STD,
- 'first_conv': 'patch_embed.proj',
- 'classifier': 'head',
- }
-
- backbone = SwinTransformer(img_size=config['resolution'], **model_kwargs)
- backbone.default_cfg = default_cfg
-
- load_pretrained(
- backbone,
- num_classes=config['num_classes'],
- in_chans=model_kwargs.get('in_chans', 3),
- filter_fn=_conv_filter,
- img_size=config['resolution'],
- pretrained_window_size=7,
- pretrained_model=''
- )
-
- return backbone
-
- def build_loss(self, config):
- loss_class = LOSSFUNC[config['loss_func']]
- loss_func = loss_class()
- return loss_func
-
- def features(self, data_dict: dict) -> torch.tensor:
- bs, t, c, h, w = data_dict['image'].shape
- inputs = data_dict['image'].view(bs, t * c, h, w)
- pred = self.model(inputs)
- return pred
-
- def classifier(self, features: torch.tensor):
- pass
-
- def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
- label = data_dict['label'].long()
- pred = pred_dict['cls']
- loss = self.loss_func(pred, label)
- loss_dict = {'overall': loss}
- return loss_dict
-
- def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
- label = data_dict['label']
- pred = pred_dict['cls']
- auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
- metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}
- return metric_batch_dict
-
- def forward(self, data_dict: dict, inference=False) -> dict:
- pred = self.features(data_dict)
- prob = torch.softmax(pred, dim=1)[:, 1]
- pred_dict = {'cls': pred, 'prob': prob, 'feat': prob}
- return pred_dict
-
-
-class Mlp(nn.Module):
- def __init__(
- self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- drop=0.
- ):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """Partition input tensor into windows.
-
- Args:
- x: Input tensor of shape (B, H, W, C)
- window_size (int): Size of each window
-
- Returns:
- windows: Output tensor of shape (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
- windows = windows.view(-1, window_size, window_size, C)
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """Reverse window partitioning.
-
- Args:
- windows: Input tensor of shape (num_windows*B, window_size, window_size, C)
- window_size (int): Size of each window
- H (int): Height of original image
- W (int): Width of original image
-
- Returns:
- x: Output tensor of shape (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size,
- window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
- x = x.view(B, H, W, -1)
- return x
-
-
-class WindowAttention(nn.Module):
- """Window based multi-head self attention (W-MSA) module with relative position bias.
-
- It supports both shifted and non-shifted window attention.
-
- Args:
- dim (int): Number of input channels
- window_size (tuple[int]): Height and width of window
- num_heads (int): Number of attention heads
- qkv_bias (bool, optional): Add learnable bias to query, key, value.
- Default: True
- qk_scale (float | None, optional): Override default qk scale of
- head_dim ** -0.5 if set
- attn_drop (float, optional): Dropout ratio of attention weight.
- Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- """
-
- def __init__(self, dim, window_size, num_heads, qkv_bias=True,
- qk_scale=None, attn_drop=0., proj_drop=0.):
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.num_heads = num_heads
- head_dim = dim // num_heads
- self.scale = qk_scale or head_dim ** -0.5
-
- # Define parameter table of relative position bias
- self.relative_position_bias_table = nn.Parameter(
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
- num_heads))
-
- # Get pair-wise relative position index for each token in window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
- coords_flatten = torch.flatten(coords, 1)
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
- relative_coords = relative_coords.permute(1, 2, 0).contiguous()
- relative_coords[:, :, 0] += self.window_size[0] - 1
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1)
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- trunc_normal_(self.relative_position_bias_table, std=.02)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """Forward pass.
-
- Args:
- x: Input features with shape (num_windows*B, N, C)
- mask: (0/-inf) mask with shape (num_windows, Wh*Ww, Wh*Ww) or None
-
- Returns:
- Output tensor after attention
- """
- B_, N, C = x.shape
- qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
- qkv = qkv.permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2]
-
- q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
-
- relative_position_bias = self.relative_position_bias_table[
- self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1],
- self.window_size[0] * self.window_size[1], -1)
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N)
- attn = attn + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def extra_repr(self) -> str:
- """Extra string representation."""
- return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
-
- def flops(self, N):
- """Calculate FLOPs for one window."""
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
-
-
-class SwinTransformerBlock(nn.Module):
- """Swin Transformer Block.
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resulotion.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(
- self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
- act_layer=nn.GELU, norm_layer=nn.LayerNorm, bottleneck=False, use_checkpoint=False
- ):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- self.use_checkpoint = use_checkpoint
-
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
- assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim,
- window_size=to_2tuple(self.window_size),
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- attn_drop=attn_drop,
- proj_drop=drop
- )
-
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(
- in_features=dim,
- hidden_features=mlp_hidden_dim,
- act_layer=act_layer,
- drop=drop
- )
-
- if self.shift_size > 0:
- # calculate attention mask for SW-MSA
- H, W = self.input_resolution
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (
- slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None)
- )
- w_slices = (
- slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None)
- )
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- # nW, window_size, window_size, 1
- mask_windows = window_partition(img_mask, self.window_size)
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(
- attn_mask != 0, float(-100.0)
- ).masked_fill(attn_mask == 0, float(0.0))
- else:
- attn_mask = None
-
- self.register_buffer("attn_mask", attn_mask)
-
- def forward_attn(self, x):
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
-
- x = self.norm1(x)
- x = x.view(B, H, W, C)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(
- x,
- shifts=(-self.shift_size, -self.shift_size),
- dims=(1, 2)
- )
- else:
- shifted_x = x
-
- # partition windows
- # nW*B, window_size, window_size, C
- x_windows = window_partition(shifted_x, self.window_size)
- # nW*B, window_size*window_size, C
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
-
- # W-MSA/SW-MSA
- # nW*B, window_size*window_size, C
- attn_windows = self.attn(x_windows, mask=self.attn_mask)
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(
- shifted_x,
- shifts=(self.shift_size, self.shift_size),
- dims=(1, 2)
- )
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
-
- return x
-
- def forward_mlp(self, x):
- return self.drop_path(self.mlp(self.norm2(x)))
-
- def forward(self, x):
- shortcut = x
- if self.use_checkpoint:
- x = checkpoint.checkpoint(self.forward_attn, x)
- else:
- x = self.forward_attn(x)
- x = shortcut + self.drop_path(x)
-
- if self.use_checkpoint:
- x = x + checkpoint.checkpoint(self.forward_mlp, x)
- else:
- x = x + self.forward_mlp(x)
-
- return x
-
- def extra_repr(self) -> str:
- return (
- f"dim={self.dim}, input_resolution={self.input_resolution}, "
- f"num_heads={self.num_heads}, window_size={self.window_size}, "
- f"shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
- )
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-
-class PatchMerging(nn.Module):
- """Patch Merging Layer.
-
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(4 * dim)
-
- def forward(self, x):
- """Forward pass.
-
- Args:
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.norm(x)
- x = self.reduction(x)
-
- return x
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.dim
- flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- return flops
-
-
-class BasicLayer(nn.Module):
- """ A basic Swin Transformer layer for one stage.
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- """
-
- def __init__(
- self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- bottleneck=False
- ):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList([
- SwinTransformerBlock(
- dim=dim,
- input_resolution=input_resolution,
- num_heads=num_heads,
- window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop,
- attn_drop=attn_drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- norm_layer=norm_layer,
- bottleneck=bottleneck if i == depth - 1 else False,
- use_checkpoint=use_checkpoint
- ) for i in range(depth)
- ])
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(
- input_resolution,
- dim=dim,
- norm_layer=norm_layer
- )
- else:
- self.downsample = None
-
- def forward(self, x):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x)
- else:
- x = blk(x)
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops()
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
-
-class PatchEmbed(nn.Module):
- r"""Image to Patch Embedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(
- self,
- img_size=(224, 224),
- patch_size=4,
- in_chans=3,
- embed_dim=96,
- norm_layer=None
- ):
- super().__init__()
- # img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [
- img_size[0] // patch_size[0],
- img_size[1] // patch_size[1]
- ]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- self.proj = nn.Conv2d(
- in_chans,
- embed_dim,
- kernel_size=patch_size,
- stride=patch_size
- )
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- B, C, H, W = x.shape
- # FIXME look at relaxing size constraints
- assert H == self.img_size[0] and W == self.img_size[1], \
- f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
- x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
- def flops(self):
- Ho, Wo = self.patches_resolution
- flops = Ho * Wo * self.embed_dim * self.in_chans * (
- self.patch_size[0] * self.patch_size[1]
- )
- if self.norm is not None:
- flops += Ho * Wo * self.embed_dim
- return flops
-
-
-class SwinTransformer(nn.Module):
- r""" Swin Transformer
- A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
- https://arxiv.org/pdf/2103.14030
- Args:
- img_size (int | tuple(int)): Input image size. Default 224
- patch_size (int | tuple(int)): Patch size. Default: 4
- in_chans (int): Number of input image channels. Default: 3
- num_classes (int): Number of classes for classification head. Default: 1000
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
- """
-
- def __init__(
- self, duration=8, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
- embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
- window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
- use_checkpoint=False, thumbnail_rows=1, bottleneck=False, device='cuda', **kwargs
- ):
- super().__init__()
-
- self.duration = duration # 4
- self.num_classes = num_classes # 2
- self.num_layers = len(depths) # [2, 2, 18, 2]
- self.embed_dim = embed_dim # 128
- self.ape = ape # True
- self.patch_norm = patch_norm # False
- self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
- self.mlp_ratio = mlp_ratio # 4 = default
- self.thumbnail_rows = thumbnail_rows # 2
- self.device = device
-
- self.img_size = img_size # 224
- self.window_size = ([window_size for _ in depths] if not isinstance(window_size, list)
- else window_size)
-
- self.frame_padding = self.duration % thumbnail_rows # 0
- if self.frame_padding != 0:
- self.frame_padding = self.thumbnail_rows - self.frame_padding
- self.duration += self.frame_padding
-
- # split image into non-overlapping patches
- thumbnail_dim = (thumbnail_rows, self.duration // thumbnail_rows) # (2, 2)
- thumbnail_size = (img_size * thumbnail_dim[0], img_size * thumbnail_dim[1])
-
- self.patch_embed = PatchEmbed(
- img_size=(img_size, img_size),
- patch_size=patch_size,
- in_chans=in_chans,
- embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None
- )
- num_patches = self.patch_embed.num_patches # 16
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution # [56, 56]
-
- # absolute position embedding
- if self.ape: # True
- self.frame_pos_embed = nn.Parameter(torch.zeros(1, self.duration, embed_dim))
- trunc_normal_(self.frame_pos_embed, std=.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
-
- # build layers
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = BasicLayer(
- dim=int(embed_dim * 2 ** i_layer),
- input_resolution=(
- patches_resolution[0] // (2 ** i_layer),
- patches_resolution[1] // (2 ** i_layer)),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=self.window_size[i_layer],
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop_rate,
- attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
- norm_layer=norm_layer,
- downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
- use_checkpoint=use_checkpoint,
- bottleneck=bottleneck
- )
- self.layers.append(layer)
-
- self.norm = norm_layer(self.num_features)
- self.avgpool = nn.AdaptiveAvgPool1d(1)
- self.head = (nn.Linear(self.num_features, num_classes)
- if num_classes > 0 else nn.Identity())
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'absolute_pos_embed', 'frame_pos_embed'}
-
- @torch.jit.ignore
- def no_weight_decay_keywords(self):
- return {'relative_position_bias_table'}
-
- def create_thumbnail(self, x):
- input_size = x.shape[-2:]
- if input_size != to_2tuple(self.img_size):
- x = nn.functional.interpolate(x, size=self.img_size, mode='bilinear')
- x = rearrange(x, 'b (th tw c) h w -> b c (th h) (tw w)',
- th=self.thumbnail_rows, c=3)
- return x
-
- def pad_frames(self, x):
- frame_num = self.duration - self.frame_padding
- x = x.view((-1, 3 * frame_num) + x.size()[2:])
- x_padding = torch.zeros((x.shape[0], 3 * self.frame_padding) +
- x.size()[2:]).to(self.device)
- x = torch.cat((x, x_padding), dim=1)
- assert x.shape[1] == 3 * self.duration, (
- 'frame number %d not the same as adjusted input size %d' %
- (x.shape[1], 3 * self.duration))
-
- return x
-
- # need to find a better way to do this, maybe torch.fold?
- def create_image_pos_embed(self):
- img_rows, img_cols = self.patches_resolution # (56, 56)
- _, _, T = self.frame_pos_embed.shape # (1, 4, embed)
- rows = img_rows // self.thumbnail_rows # 28
- cols = img_cols // (self.duration // self.thumbnail_rows) # 28
- img_pos_embed = torch.zeros(img_rows, img_cols, T).to(self.device) # [56, 56, embed]
- for i in range(self.duration):
- r_indx = (i // self.thumbnail_rows) * rows
- c_indx = (i % self.thumbnail_rows) * cols
- img_pos_embed[r_indx:r_indx + rows, c_indx:c_indx + cols] = self.frame_pos_embed[0, i]
-
- return img_pos_embed.reshape(-1, T) # [56*56, embed]
-
- def forward_features(self, x):
- if self.frame_padding > 0:
- x = self.pad_frames(x)
- else:
- x = x.view((-1, 3 * self.duration) + x.size()[2:])
-
- x = self.create_thumbnail(x)
- x = nn.functional.interpolate(x, size=self.img_size, mode='bilinear') # [B, 3, 224, 224]
-
- x = self.patch_embed(x) # [B, 56*56, embed]
- if self.ape:
- img_pos_embed = self.create_image_pos_embed()
- x = x + img_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x)
-
- x = self.norm(x) # B L C
- x = self.avgpool(x.transpose(1, 2)) # B C 1
- x = torch.flatten(x, 1)
- return x
-
- def forward(self, x):
- x = self.forward_features(x)
- x = self.head(x)
- return x
-
- def flops(self):
- flops = 0
- flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
- flops += layer.flops()
- flops += (self.num_features * self.patches_resolution[0] *
- self.patches_resolution[1] // (2 ** self.num_layers))
- flops += self.num_features * self.num_classes
- return flops
-
-def load_pretrained(
- model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, img_size=224,
- num_patches=196, pretrained_window_size=7, pretrained_model="", strict=True
-):
- if cfg is None:
- cfg = getattr(model, 'default_cfg')
- if cfg is None or 'url' not in cfg or not cfg['url']:
- _logger.warning("Pretrained model URL is invalid, using random initialization.")
- return
-
- if len(pretrained_model) == 0:
- state_dict = load_state_dict_from_url(cfg['url'], map_location='cpu')
- else:
- try:
- state_dict = torch.load(pretrained_model)['model']
- except:
- state_dict = torch.load(pretrained_model)
-
- if filter_fn is not None:
- state_dict = filter_fn(state_dict)
-
- if in_chans == 1:
- conv1_name = cfg['first_conv']
- _logger.info(
- 'Converting first conv (%s) pretrained weights from 3 to 1 channel',
- conv1_name
- )
- conv1_weight = state_dict[conv1_name + '.weight']
- conv1_type = conv1_weight.dtype
- conv1_weight = conv1_weight.float()
- O, I, J, K = conv1_weight.shape
- if I > 3:
- assert conv1_weight.shape[1] % 3 == 0
- # For models with space2depth stems
- conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
- conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
- else:
- conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
- conv1_weight = conv1_weight.to(conv1_type)
- state_dict[conv1_name + '.weight'] = conv1_weight
- elif in_chans != 3:
- conv1_name = cfg['first_conv']
- conv1_weight = state_dict[conv1_name + '.weight']
- conv1_type = conv1_weight.dtype
- conv1_weight = conv1_weight.float()
- O, I, J, K = conv1_weight.shape
- if I != 3:
- _logger.warning(
- 'Deleting first conv (%s) from pretrained weights.',
- conv1_name
- )
- del state_dict[conv1_name + '.weight']
- strict = False
- else:
- _logger.info(
- 'Repeating first conv (%s) weights in channel dim.',
- conv1_name
- )
- repeat = int(math.ceil(in_chans / 3))
- conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
- conv1_weight *= (3 / float(in_chans))
- conv1_weight = conv1_weight.to(conv1_type)
- state_dict[conv1_name + '.weight'] = conv1_weight
-
- classifier_name = cfg['classifier']
- if num_classes == 1000 and cfg['num_classes'] == 1001:
- # special case for imagenet trained models with extra background class
- classifier_weight = state_dict[classifier_name + '.weight']
- state_dict[classifier_name + '.weight'] = classifier_weight[1:]
- classifier_bias = state_dict[classifier_name + '.bias']
- state_dict[classifier_name + '.bias'] = classifier_bias[1:]
- elif num_classes != cfg['num_classes']:
- # discard fully connected for all other differences
- del state_dict['model'][classifier_name + '.weight']
- del state_dict['model'][classifier_name + '.bias']
- strict = False
- '''
- ## Resizing the positional embeddings in case they don't match
- if img_size != cfg['input_size'][1]:
- pos_embed = state_dict['pos_embed']
- cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)
- other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)
- new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest')
- new_pos_embed = new_pos_embed.transpose(1, 2)
- new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)
- state_dict['pos_embed'] = new_pos_embed
- '''
-
- # remove window_size related parameters
- window_size = (model.window_size)[0]
- print(pretrained_window_size, window_size)
-
- new_state_dict = state_dict['model'].copy()
- for key in state_dict['model']:
- if 'attn_mask' in key:
- del new_state_dict[key]
-
- if 'relative_position_index' in key:
- del new_state_dict[key]
-
- # resize it
- if 'relative_position_bias_table' in key:
- pretrained_table = state_dict['model'][key]
- pretrained_table_size = int(math.sqrt(pretrained_table.shape[0]))
- table_size = int(math.sqrt(model.state_dict()[key].shape[0]))
- if pretrained_table_size != table_size:
- table = pretrained_table.permute(1, 0).view(1, -1, pretrained_table_size, pretrained_table_size)
- table = nn.functional.interpolate(table, size=table_size, mode='bilinear')
- table = table.view(-1, table_size * table_size).permute(1, 0)
- new_state_dict[key] = table
-
- for key in model.state_dict():
- if 'bottleneck_norm' in key:
- attn_key = key.replace('bottleneck_norm', 'norm1')
- # print (key, attn_key)
- new_state_dict[key] = new_state_dict[attn_key]
-
- print('loading weights....')
- ## Loading the weights
- model.load_state_dict(new_state_dict, strict=False)
-
-
-def _conv_filter(state_dict, patch_size=4):
- """ convert patch embedding weight from manual patchify + linear proj to conv"""
- out_dict = {}
- for k, v in state_dict.items():
- if 'patch_embed.proj.weight' in k:
- if v.shape[-1] != patch_size:
- patch_size = v.shape[-1]
- v = v.reshape((v.shape[0], 3, patch_size, patch_size))
- out_dict[k] = v
- return out_dict
\ No newline at end of file
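The least obvious step in the deleted video model above is create_thumbnail: frames stacked along the channel axis are tiled into one large image before patch embedding. A minimal sketch of that rearrange, assuming the defaults annotated in the comments (duration=4, thumbnail_rows=2, img_size=224):

```python
import torch
from einops import rearrange

# 4 RGB frames stacked channel-wise -> a single 2x2 grid thumbnail (th=2, tw=2)
frames = torch.randn(1, 4 * 3, 224, 224)  # (B, T*C, H, W)
thumb = rearrange(frames, 'b (th tw c) h w -> b c (th h) (tw w)', th=2, c=3)
print(thumb.shape)  # torch.Size([1, 3, 448, 448])
```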
diff --git a/base_miner/DFB/detectors/ucf_detector.py b/base_miner/DFB/detectors/ucf_detector.py
deleted file mode 100644
index 51b16d81..00000000
--- a/base_miner/DFB/detectors/ucf_detector.py
+++ /dev/null
@@ -1,486 +0,0 @@
-'''
-# Source: https://github.com/SCLBD/DeepfakeBench/blob/main/training/detectors/ucf_detector.py
-# author: Zhiyuan Yan
-# email: zhiyuanyan@link.cuhk.edu.cn
-# date: 2023-0706
-# description: Class for the UCFDetector
-
-Functions in the Class are summarized as:
-1. __init__: Initialization
-2. build_backbone: Backbone-building
-3. build_loss: Loss-function-building
-4. features: Feature-extraction
-5. classifier: Classification
-6. get_losses: Loss-computation
-7. get_train_metrics: Training-metrics-computation
-8. get_test_metrics: Testing-metrics-computation
-9. forward: Forward-propagation
-
-Reference:
-@article{yan2023ucf,
- title={UCF: Uncovering Common Features for Generalizable Deepfake Detection},
- author={Yan, Zhiyuan and Zhang, Yong and Fan, Yanbo and Wu, Baoyuan},
- journal={arXiv preprint arXiv:2304.13949},
- year={2023}
-}
-'''
-
-import os
-import datetime
-import logging
-import random
-import numpy as np
-from sklearn import metrics
-from typing import Union
-from collections import defaultdict
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-from torch.nn import DataParallel
-from torch.utils.tensorboard import SummaryWriter
-
-from metrics.base_metrics_class import calculate_metrics_for_train
-
-from DFB.detectors.base_detector import AbstractDetector
-from DFB.detectors import DETECTOR
-from DFB.networks import BACKBONE
-from DFB.loss import LOSSFUNC
-from DFB.config.constants import WEIGHTS_DIR
-
-logger = logging.getLogger(__name__)
-
-
-@DETECTOR.register_module(module_name='ucf')
-class UCFDetector(AbstractDetector):
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.num_classes = config['backbone_config']['num_classes']
- self.encoder_feat_dim = config['encoder_feat_dim']
- self.half_fingerprint_dim = self.encoder_feat_dim//2
-
- self.encoder_f = self.build_backbone(config)
- self.encoder_c = self.build_backbone(config)
-
- self.loss_func = self.build_loss(config)
- self.prob, self.label = [], []
- self.correct, self.total = 0, 0
-
- # basic function
- self.lr = nn.LeakyReLU(inplace=True)
- self.do = nn.Dropout(0.2)
- self.pool = nn.AdaptiveAvgPool2d(1)
-
- # conditional gan
- self.con_gan = Conditional_UNet()
-
- # head
- specific_task_number = config['specific_task_number']
-
- self.head_spe = Head(
- in_f=self.half_fingerprint_dim,
- hidden_dim=self.encoder_feat_dim,
- out_f=specific_task_number
- )
- self.head_sha = Head(
- in_f=self.half_fingerprint_dim,
- hidden_dim=self.encoder_feat_dim,
- out_f=self.num_classes
- )
- self.block_spe = Conv2d1x1(
- in_f=self.encoder_feat_dim,
- hidden_dim=self.half_fingerprint_dim,
- out_f=self.half_fingerprint_dim
- )
- self.block_sha = Conv2d1x1(
- in_f=self.encoder_feat_dim,
- hidden_dim=self.half_fingerprint_dim,
- out_f=self.half_fingerprint_dim
- )
-
- def build_backbone(self, config):
- # prepare the backbone
- backbone_class = BACKBONE[config['backbone_name']]
- model_config = config['backbone_config']
- backbone = backbone_class(model_config)
-
- if 'pretrained' in config:
- pretrained_path = config['pretrained']
- if isinstance(pretrained_path, dict):
- if 'local_path' in pretrained_path:
- pretrained_path = pretrained_path['local_path']
- elif 'filename' in pretrained_path:
- pretrained_path = pretrained_path['filename']
- else:
- pretrained_path = pretrained_path.split('/')[-1]
-
- if not os.path.isabs(pretrained_path):
- pretrained_path = os.path.join(WEIGHTS_DIR, pretrained_path)
-
- logger.info(f"Loading {pretrained_path}")
- state_dict = torch.load(pretrained_path)
- for name, weights in state_dict.items():
- if 'pointwise' in name:
- state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)
- state_dict = {k:v for k, v in state_dict.items() if 'fc' not in k}
- backbone.load_state_dict(state_dict, False)
- logger.info('Load pretrained model successfully!')
- return backbone
-
- def build_loss(self, config):
- cls_loss_class = LOSSFUNC[config['loss_func']['cls_loss']]
- spe_loss_class = LOSSFUNC[config['loss_func']['spe_loss']]
- con_loss_class = LOSSFUNC[config['loss_func']['con_loss']]
- rec_loss_class = LOSSFUNC[config['loss_func']['rec_loss']]
- cls_loss_func = cls_loss_class()
- spe_loss_func = spe_loss_class()
- con_loss_func = con_loss_class(margin=3.0)
- rec_loss_func = rec_loss_class()
- loss_func = {
- 'cls': cls_loss_func,
- 'spe': spe_loss_func,
- 'con': con_loss_func,
- 'rec': rec_loss_func,
- }
- return loss_func
-
- def features(self, data_dict: dict) -> torch.tensor:
- cat_data = data_dict['image']
- # encoder
- f_all = self.encoder_f.features(cat_data)
- c_all = self.encoder_c.features(cat_data)
- feat_dict = {'forgery': f_all, 'content': c_all}
- return feat_dict
-
- def classifier(self, features: torch.tensor) -> torch.tensor:
- # classification, multi-task
- # split the features into the specific and common forgery
- f_spe = self.block_spe(features)
- f_share = self.block_sha(features)
- return f_spe, f_share
-
- def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
- if 'label_spe' in data_dict and 'recontruction_imgs' in pred_dict:
- return self.get_train_losses(data_dict, pred_dict)
- else: # test mode
- return self.get_test_losses(data_dict, pred_dict)
-
- def get_train_losses(self, data_dict: dict, pred_dict: dict) -> dict:
- # get combined, real, fake imgs
- cat_data = data_dict['image']
- real_img, fake_img = cat_data.chunk(2, dim=0)
- # get the reconstruction imgs
- reconstruction_image_1, \
- reconstruction_image_2, \
- self_reconstruction_image_1, \
- self_reconstruction_image_2 \
- = pred_dict['recontruction_imgs']
- # get label
- label = data_dict['label']
- label_spe = data_dict['label_spe']
- # get pred
- pred = pred_dict['cls']
- pred_spe = pred_dict['cls_spe']
-
- # 1. classification loss for common features
- loss_sha = self.loss_func['cls'](pred, label)
-
- # 2. classification loss for specific features
- loss_spe = self.loss_func['spe'](pred_spe, label_spe)
-
- # 3. reconstruction loss
- self_loss_reconstruction_1 = self.loss_func['rec'](fake_img, self_reconstruction_image_1)
- self_loss_reconstruction_2 = self.loss_func['rec'](real_img, self_reconstruction_image_2)
- cross_loss_reconstruction_1 = self.loss_func['rec'](fake_img, reconstruction_image_2)
- cross_loss_reconstruction_2 = self.loss_func['rec'](real_img, reconstruction_image_1)
- loss_reconstruction = \
- self_loss_reconstruction_1 + self_loss_reconstruction_2 + \
- cross_loss_reconstruction_1 + cross_loss_reconstruction_2
-
-        # 4. contrastive loss
- common_features = pred_dict['feat']
- specific_features = pred_dict['feat_spe']
- loss_con = self.loss_func['con'](common_features, specific_features, label_spe)
-
- # 5. total loss
- loss = loss_sha + 0.1*loss_spe + 0.3*loss_reconstruction + 0.05*loss_con
- loss_dict = {
- 'overall': loss,
- 'common': loss_sha,
- 'specific': loss_spe,
- 'reconstruction': loss_reconstruction,
- 'contrastive': loss_con,
- }
- return loss_dict
-
- def get_test_losses(self, data_dict: dict, pred_dict: dict) -> dict:
- # get label
- label = data_dict['label']
- # get pred
- pred = pred_dict['cls']
- # for test mode, only classification loss for common features
- loss = self.loss_func['cls'](pred, label)
- loss_dict = {'common': loss}
- return loss_dict
-
- def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
- def get_accracy(label, output):
- _, prediction = torch.max(output, 1) # argmax
- correct = (prediction == label).sum().item()
- accuracy = correct / prediction.size(0)
- return accuracy
-
- # get pred and label
- label = data_dict['label']
- pred = pred_dict['cls']
- label_spe = data_dict['label_spe']
- pred_spe = pred_dict['cls_spe']
-
- # compute metrics for batch data
- auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
- acc_spe = get_accracy(label_spe.detach(), pred_spe.detach())
- metric_batch_dict = {'acc': acc, 'acc_spe': acc_spe, 'auc': auc, 'eer': eer, 'ap': ap}
-        # we don't compute video-level metrics for training
- return metric_batch_dict
-
- def forward(self, data_dict: dict, inference=False) -> dict:
- # split the features into the content and forgery
- features = self.features(data_dict)
- forgery_features, content_features = features['forgery'], features['content']
- # get the prediction by classifier (split the common and specific forgery)
- f_spe, f_share = self.classifier(forgery_features)
-
- if inference:
- # inference only consider share loss
- out_sha, sha_feat = self.head_sha(f_share)
- out_spe, spe_feat = self.head_spe(f_spe)
- prob_sha = torch.softmax(out_sha, dim=1)[:, 1]
- self.prob.append(
- prob_sha
- .detach()
- .squeeze()
- .cpu()
- .numpy()
- )
- _, prediction_class = torch.max(out_sha, 1)
- if 'label' in data_dict:
- self.label.append(
- data_dict['label']
- .detach()
- .squeeze()
- .cpu()
- .numpy()
- )
- # deal with acc
- common_label = (data_dict['label'] >= 1)
- correct = (prediction_class == common_label).sum().item()
- self.correct += correct
- self.total += data_dict['label'].size(0)
-
- pred_dict = {'cls': out_sha, 'feat': sha_feat}
- return pred_dict
-
- bs = f_share.size(0)
-        # index-shuffling augmentation of the shared features in training mode
- aug_idx = random.random()
- if aug_idx < 0.7:
- # real
- idx_list = list(range(0, bs//2))
- random.shuffle(idx_list)
- f_share[0: bs//2] = f_share[idx_list]
- # fake
- idx_list = list(range(bs//2, bs))
- random.shuffle(idx_list)
- f_share[bs//2: bs] = f_share[idx_list]
-
- # concat spe and share to obtain new_f_all
- f_all = torch.cat((f_spe, f_share), dim=1)
-
- # reconstruction loss
- f2, f1 = f_all.chunk(2, dim=0)
- c2, c1 = content_features.chunk(2, dim=0)
-
- # ==== self reconstruction ==== #
- # f1 + c1 -> f11, f11 + c1 -> near~I1
- self_reconstruction_image_1 = self.con_gan(f1, c1)
-
- # f2 + c2 -> f2, f2 + c2 -> near~I2
- self_reconstruction_image_2 = self.con_gan(f2, c2)
-
- # ==== cross combine ==== #
- reconstruction_image_1 = self.con_gan(f1, c2)
- reconstruction_image_2 = self.con_gan(f2, c1)
-
- # head for spe and sha
- out_spe, spe_feat = self.head_spe(f_spe)
- out_sha, sha_feat = self.head_sha(f_share)
-
- # get the probability of the pred
- prob_sha = torch.softmax(out_sha, dim=1)[:, 1]
- prob_spe = torch.softmax(out_spe, dim=1)[:, 1]
-
- # build the prediction dict for each output
- pred_dict = {
- 'cls': out_sha,
- 'prob': prob_sha,
- 'feat': sha_feat,
- 'cls_spe': out_spe,
- 'prob_spe': prob_spe,
- 'feat_spe': spe_feat,
- 'feat_content': content_features,
- 'recontruction_imgs': (
- reconstruction_image_1,
- reconstruction_image_2,
- self_reconstruction_image_1,
- self_reconstruction_image_2
- )
- }
- return pred_dict
-
-def sn_double_conv(in_channels, out_channels):
- return nn.Sequential(
- nn.utils.spectral_norm(
- nn.Conv2d(in_channels, in_channels, 3, padding=1)),
- nn.utils.spectral_norm(
- nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=2)),
- nn.LeakyReLU(0.2, inplace=True)
- )
-
-def r_double_conv(in_channels, out_channels):
- return nn.Sequential(
- nn.Conv2d(in_channels, out_channels, 3, padding=1),
- nn.ReLU(inplace=True),
- nn.Conv2d(out_channels, out_channels, 3, padding=1),
- nn.ReLU(inplace=True)
- )
-
-class AdaIN(nn.Module):
- def __init__(self, eps=1e-5):
- super().__init__()
- self.eps = eps
- # self.l1 = nn.Linear(num_classes, in_channel*4, bias=True) #bias is good :)
-
- def c_norm(self, x, bs, ch, eps=1e-7):
- # assert isinstance(x, torch.cuda.FloatTensor)
- x_var = x.var(dim=-1) + eps
- x_std = x_var.sqrt().view(bs, ch, 1, 1)
- x_mean = x.mean(dim=-1).view(bs, ch, 1, 1)
- return x_std, x_mean
-
- def forward(self, x, y):
- assert x.size(0)==y.size(0)
- size = x.size()
- bs, ch = size[:2]
- x_ = x.view(bs, ch, -1)
- y_ = y.reshape(bs, ch, -1)
- x_std, x_mean = self.c_norm(x_, bs, ch, eps=self.eps)
- y_std, y_mean = self.c_norm(y_, bs, ch, eps=self.eps)
- out = ((x - x_mean.expand(size)) / x_std.expand(size)) \
- * y_std.expand(size) + y_mean.expand(size)
- return out
-
-class Conditional_UNet(nn.Module):
-
- def init_weight(self, std=0.2):
- for m in self.modules():
- cn = m.__class__.__name__
- if cn.find('Conv') != -1:
- m.weight.data.normal_(0., std)
- elif cn.find('Linear') != -1:
- m.weight.data.normal_(1., std)
- m.bias.data.fill_(0)
-
- def __init__(self):
- super(Conditional_UNet, self).__init__()
-
- self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
- self.maxpool = nn.MaxPool2d(2)
- self.dropout = nn.Dropout(p=0.3)
- #self.dropout_half = HalfDropout(p=0.3)
-
- self.adain3 = AdaIN()
- self.adain2 = AdaIN()
- self.adain1 = AdaIN()
-
- self.dconv_up3 = r_double_conv(512, 256)
- self.dconv_up2 = r_double_conv(256, 128)
- self.dconv_up1 = r_double_conv(128, 64)
-
- self.conv_last = nn.Conv2d(64, 3, 1)
- self.up_last = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
- self.activation = nn.Tanh()
- #self.init_weight()
-
- def forward(self, c, x): # c is the style and x is the content
- x = self.adain3(x, c)
- x = self.upsample(x)
- x = self.dropout(x)
- x = self.dconv_up3(x)
- c = self.upsample(c)
- c = self.dropout(c)
- c = self.dconv_up3(c)
-
- x = self.adain2(x, c)
- x = self.upsample(x)
- x = self.dropout(x)
- x = self.dconv_up2(x)
- c = self.upsample(c)
- c = self.dropout(c)
- c = self.dconv_up2(c)
-
- x = self.adain1(x, c)
- x = self.upsample(x)
- x = self.dropout(x)
- x = self.dconv_up1(x)
-
- x = self.conv_last(x)
- out = self.up_last(x)
-
- return self.activation(out)
-
-class MLP(nn.Module):
- def __init__(self, in_f, hidden_dim, out_f):
- super(MLP, self).__init__()
- self.pool = nn.AdaptiveAvgPool2d(1)
- self.mlp = nn.Sequential(nn.Linear(in_f, hidden_dim),
- nn.LeakyReLU(inplace=True),
- nn.Linear(hidden_dim, hidden_dim),
- nn.LeakyReLU(inplace=True),
- nn.Linear(hidden_dim, out_f),)
-
- def forward(self, x):
- x = self.pool(x)
- x = self.mlp(x)
- return x
-
-class Conv2d1x1(nn.Module):
- def __init__(self, in_f, hidden_dim, out_f):
- super(Conv2d1x1, self).__init__()
- self.conv2d = nn.Sequential(nn.Conv2d(in_f, hidden_dim, 1, 1),
- nn.LeakyReLU(inplace=True),
- nn.Conv2d(hidden_dim, hidden_dim, 1, 1),
- nn.LeakyReLU(inplace=True),
- nn.Conv2d(hidden_dim, out_f, 1, 1),)
-
- def forward(self, x):
- x = self.conv2d(x)
- return x
-
-class Head(nn.Module):
- def __init__(self, in_f, hidden_dim, out_f):
- super(Head, self).__init__()
- self.do = nn.Dropout(0.2)
- self.pool = nn.AdaptiveAvgPool2d(1)
- self.mlp = nn.Sequential(nn.Linear(in_f, hidden_dim),
- nn.LeakyReLU(inplace=True),
- nn.Linear(hidden_dim, out_f),)
-
- def forward(self, x):
- bs = x.size()[0]
- x_feat = self.pool(x).view(bs, -1)
- x = self.mlp(x_feat)
- x = self.do(x)
- return x, x_feat
-
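For reference, the AdaIN block that conditions Conditional_UNet above simply re-normalizes the content features to carry the style features' per-channel statistics. A standalone sketch of the same computation (shapes chosen arbitrarily):

```python
import torch

def adain(x, y, eps=1e-5):
    # Give x the per-channel mean/std of y, mirroring AdaIN.c_norm/forward above
    b, c = x.shape[:2]
    x_flat, y_flat = x.view(b, c, -1), y.reshape(b, c, -1)
    x_std = (x_flat.var(dim=-1) + eps).sqrt().view(b, c, 1, 1)
    x_mean = x_flat.mean(dim=-1).view(b, c, 1, 1)
    y_std = (y_flat.var(dim=-1) + eps).sqrt().view(b, c, 1, 1)
    y_mean = y_flat.mean(dim=-1).view(b, c, 1, 1)
    return (x - x_mean) / x_std * y_std + y_mean

print(adain(torch.randn(2, 512, 8, 8), torch.randn(2, 512, 8, 8)).shape)
```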
diff --git a/base_miner/DFB/logger.py b/base_miner/DFB/logger.py
deleted file mode 100644
index 9ee268d4..00000000
--- a/base_miner/DFB/logger.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import os
-import logging
-
-import torch.distributed as dist
-
-class RankFilter(logging.Filter):
- def __init__(self, rank):
- super().__init__()
- self.rank = rank
-
- def filter(self, record):
- return dist.get_rank() == self.rank
-
-def create_logger(log_path):
- # Create log path
-    if not os.path.isdir(os.path.dirname(log_path)):
- os.makedirs(os.path.dirname(log_path), exist_ok=True)
-
- # Create logger object
- logger = logging.getLogger()
- logger.setLevel(logging.INFO)
- # Create file handler and set the formatter
- fh = logging.FileHandler(log_path)
- formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
- fh.setFormatter(formatter)
-
- # Add the file handler to the logger
- logger.addHandler(fh)
-
- # Add a stream handler to print to console
- sh = logging.StreamHandler()
- sh.setLevel(logging.INFO) # Set logging level for stream handler
- sh.setFormatter(formatter)
- logger.addHandler(sh)
-
- return logger
\ No newline at end of file
diff --git a/base_miner/DFB/loss/__init__.py b/base_miner/DFB/loss/__init__.py
deleted file mode 100644
index 9ad78e9a..00000000
--- a/base_miner/DFB/loss/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import os
-import sys
-current_file_path = os.path.abspath(__file__)
-parent_dir = os.path.dirname(os.path.dirname(current_file_path))
-project_root_dir = os.path.dirname(parent_dir)
-sys.path.append(parent_dir)
-sys.path.append(project_root_dir)
-
-from metrics.registry import LOSSFUNC
-
-from .cross_entropy_loss import CrossEntropyLoss
-from .contrastive_regularization import ContrastiveLoss
-from .l1_loss import L1Loss
\ No newline at end of file
diff --git a/base_miner/DFB/loss/abstract_loss_func.py b/base_miner/DFB/loss/abstract_loss_func.py
deleted file mode 100644
index 45d3324e..00000000
--- a/base_miner/DFB/loss/abstract_loss_func.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import torch.nn as nn
-
-class AbstractLossClass(nn.Module):
- """Abstract class for loss functions."""
- def __init__(self):
- super(AbstractLossClass, self).__init__()
-
- def forward(self, pred, label):
- """
- Args:
- pred: prediction of the model
- label: ground truth label
-
- Return:
- loss: loss value
- """
- raise NotImplementedError('Each subclass should implement the forward method.')
diff --git a/base_miner/DFB/loss/contrastive_regularization.py b/base_miner/DFB/loss/contrastive_regularization.py
deleted file mode 100644
index 8e5bb7c3..00000000
--- a/base_miner/DFB/loss/contrastive_regularization.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import random
-from collections import defaultdict
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from .abstract_loss_func import AbstractLossClass
-from metrics.registry import LOSSFUNC
-
-
-def swap_spe_features(type_list, value_list):
- type_list = type_list.cpu().numpy().tolist()
- # get index
- index_list = list(range(len(type_list)))
-
- # init a dict, where its key is the type and value is the index
- spe_dict = defaultdict(list)
-
- # do for-loop to get spe dict
- for i, one_type in enumerate(type_list):
- spe_dict[one_type].append(index_list[i])
-
- # shuffle the value list of each key
- for keys in spe_dict.keys():
- random.shuffle(spe_dict[keys])
-
- # generate a new index list for the value list
- new_index_list = []
- for one_type in type_list:
- value = spe_dict[one_type].pop()
- new_index_list.append(value)
-
- # swap the value_list by new_index_list
- value_list_new = value_list[new_index_list]
-
- return value_list_new
-
-
-@LOSSFUNC.register_module(module_name="contrastive_regularization")
-class ContrastiveLoss(AbstractLossClass):
- def __init__(self, margin=1.0):
- super().__init__()
- self.margin = margin
-
- def contrastive_loss(self, anchor, positive, negative):
- dist_pos = F.pairwise_distance(anchor, positive)
- dist_neg = F.pairwise_distance(anchor, negative)
-        # Hinge loss: nonzero when the anchor is not closer to the positive than to the negative by at least the margin
- loss = torch.mean(torch.clamp(dist_pos - dist_neg + self.margin, min=0.0))
- return loss
-
- def forward(self, common, specific, spe_label):
- # prepare
- bs = common.shape[0]
- real_common, fake_common = common.chunk(2)
- ### common real
- idx_list = list(range(0, bs//2))
- random.shuffle(idx_list)
- real_common_anchor = common[idx_list]
- ### common fake
- idx_list = list(range(bs//2, bs))
- random.shuffle(idx_list)
- fake_common_anchor = common[idx_list]
- ### specific
- specific_anchor = swap_spe_features(spe_label, specific)
- real_specific_anchor, fake_specific_anchor = specific_anchor.chunk(2)
- real_specific, fake_specific = specific.chunk(2)
-
- # Compute the contrastive loss of common between real and fake
- loss_realcommon = self.contrastive_loss(real_common, real_common_anchor, fake_common_anchor)
- loss_fakecommon = self.contrastive_loss(fake_common, fake_common_anchor, real_common_anchor)
-
-        # Compute the contrastive loss of specific features between real and fake
- loss_realspecific = self.contrastive_loss(real_specific, real_specific_anchor, fake_specific_anchor)
- loss_fakespecific = self.contrastive_loss(fake_specific, fake_specific_anchor, real_specific_anchor)
-
- # Compute the final loss as the sum of all contrastive losses
- loss = loss_realcommon + loss_fakecommon + loss_fakespecific + loss_realspecific
- return loss
\ No newline at end of file
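The per-anchor term computed by contrastive_loss above is a standard margin (triplet-style) hinge; a self-contained restatement with toy tensors:

```python
import torch
import torch.nn.functional as F

def triplet_margin(anchor, positive, negative, margin=3.0):
    # Nonzero whenever the anchor is not closer to the positive than to the
    # negative by at least the margin
    d_pos = F.pairwise_distance(anchor, positive)
    d_neg = F.pairwise_distance(anchor, negative)
    return torch.clamp(d_pos - d_neg + margin, min=0.0).mean()

anchor, positive, negative = (torch.randn(4, 16) for _ in range(3))
print(triplet_margin(anchor, positive, negative).item())
```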
diff --git a/base_miner/DFB/loss/cross_entropy_loss.py b/base_miner/DFB/loss/cross_entropy_loss.py
deleted file mode 100644
index efa7123e..00000000
--- a/base_miner/DFB/loss/cross_entropy_loss.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import torch.nn as nn
-from .abstract_loss_func import AbstractLossClass
-from metrics.registry import LOSSFUNC
-
-
-@LOSSFUNC.register_module(module_name="cross_entropy")
-class CrossEntropyLoss(AbstractLossClass):
- def __init__(self):
- super().__init__()
- self.loss_fn = nn.CrossEntropyLoss()
-
- def forward(self, inputs, targets):
- """
- Computes the cross-entropy loss.
-
- Args:
- inputs: A PyTorch tensor of size (batch_size, num_classes) containing the predicted scores.
- targets: A PyTorch tensor of size (batch_size) containing the ground-truth class indices.
-
- Returns:
- A scalar tensor representing the cross-entropy loss.
- """
- # Compute the cross-entropy loss
- loss = self.loss_fn(inputs, targets)
-
- return loss
\ No newline at end of file
diff --git a/base_miner/DFB/loss/l1_loss.py b/base_miner/DFB/loss/l1_loss.py
deleted file mode 100644
index f2bfdedb..00000000
--- a/base_miner/DFB/loss/l1_loss.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import torch.nn as nn
-from .abstract_loss_func import AbstractLossClass
-from metrics.registry import LOSSFUNC
-
-
-@LOSSFUNC.register_module(module_name="l1loss")
-class L1Loss(AbstractLossClass):
- def __init__(self):
- super().__init__()
- self.loss_fn = nn.L1Loss()
-
- def forward(self, inputs, targets):
- """
- Computes the l1 loss.
- """
- # Compute the l1 loss
- loss = self.loss_fn(inputs, targets)
-
- return loss
\ No newline at end of file
diff --git a/base_miner/DFB/metrics/__init__.py b/base_miner/DFB/metrics/__init__.py
deleted file mode 100644
index 1ed99d5d..00000000
--- a/base_miner/DFB/metrics/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-import os
-import sys
-current_file_path = os.path.abspath(__file__)
-parent_dir = os.path.dirname(os.path.dirname(current_file_path))
-project_root_dir = os.path.dirname(parent_dir)
-sys.path.append(parent_dir)
-sys.path.append(project_root_dir)
\ No newline at end of file
diff --git a/base_miner/DFB/metrics/base_metrics_class.py b/base_miner/DFB/metrics/base_metrics_class.py
deleted file mode 100644
index 9722555a..00000000
--- a/base_miner/DFB/metrics/base_metrics_class.py
+++ /dev/null
@@ -1,205 +0,0 @@
-import numpy as np
-from sklearn import metrics
-from collections import defaultdict
-import torch
-import torch.nn as nn
-
-
-def get_accracy(output, label):
- _, prediction = torch.max(output, 1) # argmax
- correct = (prediction == label).sum().item()
- accuracy = correct / prediction.size(0)
- return accuracy
-
-
-def get_prediction(output, label):
- prob = nn.functional.softmax(output, dim=1)[:, 1]
- prob = prob.view(prob.size(0), 1)
- label = label.view(label.size(0), 1)
- #print(prob.size(), label.size())
- datas = torch.cat((prob, label.float()), dim=1)
- return datas
-
-
-def calculate_metrics_for_train(label, output):
- if output.size(1) == 2:
- prob = torch.softmax(output, dim=1)[:, 1]
- else:
- prob = output
-
- # Accuracy
- _, prediction = torch.max(output, 1)
- correct = (prediction == label).sum().item()
- accuracy = correct / prediction.size(0)
-
- # Average Precision
- y_true = label.cpu().detach().numpy()
- y_pred = prob.cpu().detach().numpy()
- ap = metrics.average_precision_score(y_true, y_pred)
-
- # AUC and EER
- try:
- fpr, tpr, thresholds = metrics.roc_curve(label.squeeze().cpu().numpy(),
- prob.squeeze().cpu().numpy(),
- pos_label=1)
- except:
- # for the case when we only have one sample
- return None, None, accuracy, ap
-
- if np.isnan(fpr[0]) or np.isnan(tpr[0]):
-        # for the case when all samples within a batch are fake/real
- auc, eer = None, None
- else:
- auc = metrics.auc(fpr, tpr)
- fnr = 1 - tpr
- eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
-
- return auc, eer, accuracy, ap
-
-
-# ------------ compute average metrics of batches---------------------
-class Metrics_batch():
- def __init__(self):
- self.tprs = []
- self.mean_fpr = np.linspace(0, 1, 100)
- self.aucs = []
- self.eers = []
- self.aps = []
-
- self.correct = 0
- self.total = 0
- self.losses = []
-
- def update(self, label, output):
- acc = self._update_acc(label, output)
- if output.size(1) == 2:
- prob = torch.softmax(output, dim=1)[:, 1]
- else:
- prob = output
- #label = 1-label
- #prob = torch.softmax(output, dim=1)[:, 1]
- auc, eer = self._update_auc(label, prob)
- ap = self._update_ap(label, prob)
-
- return acc, auc, eer, ap
-
- def _update_auc(self, lab, prob):
- fpr, tpr, thresholds = metrics.roc_curve(lab.squeeze().cpu().numpy(),
- prob.squeeze().cpu().numpy(),
- pos_label=1)
- if np.isnan(fpr[0]) or np.isnan(tpr[0]):
- return -1, -1
-
- auc = metrics.auc(fpr, tpr)
- interp_tpr = np.interp(self.mean_fpr, fpr, tpr)
- interp_tpr[0] = 0.0
- self.tprs.append(interp_tpr)
- self.aucs.append(auc)
-
- # return auc
-
- # EER
- fnr = 1 - tpr
- eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
- self.eers.append(eer)
-
- return auc, eer
-
- def _update_acc(self, lab, output):
- _, prediction = torch.max(output, 1) # argmax
- correct = (prediction == lab).sum().item()
- accuracy = correct / prediction.size(0)
- # self.accs.append(accuracy)
- self.correct = self.correct+correct
- self.total = self.total+lab.size(0)
- return accuracy
-
- def _update_ap(self, label, prob):
- y_true = label.cpu().detach().numpy()
- y_pred = prob.cpu().detach().numpy()
- ap = metrics.average_precision_score(y_true,y_pred)
- self.aps.append(ap)
-
- return np.mean(ap)
-
- def get_mean_metrics(self):
- mean_acc, std_acc = self.correct/self.total, 0
- mean_auc, std_auc = self._mean_auc()
- mean_err, std_err = np.mean(self.eers), np.std(self.eers)
- mean_ap, std_ap = np.mean(self.aps), np.std(self.aps)
-
- return {'acc':mean_acc, 'auc':mean_auc, 'eer':mean_err, 'ap':mean_ap}
-
- def _mean_auc(self):
- mean_tpr = np.mean(self.tprs, axis=0)
- mean_tpr[-1] = 1.0
- mean_auc = metrics.auc(self.mean_fpr, mean_tpr)
- std_auc = np.std(self.aucs)
- return mean_auc, std_auc
-
- def clear(self):
- self.tprs.clear()
- self.aucs.clear()
- # self.accs.clear()
- self.correct=0
- self.total=0
- self.eers.clear()
- self.aps.clear()
- self.losses.clear()
-
-
-# ------------ compute average metrics of all data ---------------------
-class Metrics_all():
- def __init__(self):
- self.probs = []
- self.labels = []
- self.correct = 0
- self.total = 0
-
- def store(self, label, output):
- prob = torch.softmax(output, dim=1)[:, 1]
- _, prediction = torch.max(output, 1) # argmax
- correct = (prediction == label).sum().item()
- self.correct += correct
- self.total += label.size(0)
- self.labels.append(label.squeeze().cpu().numpy())
- self.probs.append(prob.squeeze().cpu().numpy())
-
- def get_metrics(self):
- y_pred = np.concatenate(self.probs)
- y_true = np.concatenate(self.labels)
- # auc
- fpr, tpr, thresholds = metrics.roc_curve(y_true,y_pred,pos_label=1)
- auc = metrics.auc(fpr, tpr)
- # eer
- fnr = 1 - tpr
- eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
- # ap
- ap = metrics.average_precision_score(y_true,y_pred)
- # acc
- acc = self.correct / self.total
- return {'acc':acc, 'auc':auc, 'eer':eer, 'ap':ap}
-
- def clear(self):
- self.probs.clear()
- self.labels.clear()
- self.correct = 0
- self.total = 0
-
-
-# only used to record a series of scalar value
-class Recorder:
- def __init__(self):
- self.sum = 0
- self.num = 0
- def update(self, item, num=1):
- if item is not None:
- self.sum += item * num
- self.num += num
- def average(self):
- if self.num == 0:
- return None
- return self.sum/self.num
- def clear(self):
- self.sum = 0
- self.num = 0
\ No newline at end of file
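The AUC/EER recipe repeated throughout the metrics code above reduces to a few sklearn calls; a standalone sketch on toy scores:

```python
import numpy as np
from sklearn import metrics

y_true = np.array([0, 0, 0, 1, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.7, 0.2])

fpr, tpr, _ = metrics.roc_curve(y_true, y_score, pos_label=1)
fnr = 1 - tpr
eer = fpr[np.nanargmin(np.abs(fnr - fpr))]  # FPR where FPR and FNR are closest
print(metrics.auc(fpr, tpr), eer)
```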
diff --git a/base_miner/DFB/metrics/registry.py b/base_miner/DFB/metrics/registry.py
deleted file mode 100644
index d13dd4ce..00000000
--- a/base_miner/DFB/metrics/registry.py
+++ /dev/null
@@ -1,20 +0,0 @@
-class Registry(object):
- def __init__(self):
- self.data = {}
-
- def register_module(self, module_name=None):
- def _register(cls):
- name = module_name
- if module_name is None:
- name = cls.__name__
- self.data[name] = cls
- return cls
- return _register
-
- def __getitem__(self, key):
- return self.data[key]
-
-BACKBONE = Registry()
-DETECTOR = Registry()
-TRAINER = Registry()
-LOSSFUNC = Registry()
\ No newline at end of file
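This small Registry is the glue used throughout the deleted DFB package: classes register themselves under a string key (e.g. @LOSSFUNC.register_module(module_name="cross_entropy")) and are later looked up by name from the config, as in DETECTOR[config['model_name']] in train_detector.py below. A minimal usage sketch, assuming the Registry class above is importable:

```python
MODELS = Registry()

@MODELS.register_module(module_name="demo")
class DemoDetector:
    pass

print(MODELS["demo"])  # the DemoDetector class, retrieved by its registered name
```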
diff --git a/base_miner/DFB/metrics/utils.py b/base_miner/DFB/metrics/utils.py
deleted file mode 100644
index 9a81618d..00000000
--- a/base_miner/DFB/metrics/utils.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from sklearn import metrics
-import numpy as np
-
-
-def parse_metric_for_print(metric_dict):
- if metric_dict is None:
- return "\n"
- str = "\n"
- str += "================================ Each dataset best metric ================================ \n"
- for key, value in metric_dict.items():
- if key != 'avg':
- str= str+ f"| {key}: "
- for k,v in value.items():
- str = str + f" {k}={v} "
- str= str+ "| \n"
- else:
- str += "============================================================================================= \n"
- str += "================================== Average best metric ====================================== \n"
- avg_dict = value
- for avg_key, avg_value in avg_dict.items():
- if avg_key == 'dataset_dict':
- for key,value in avg_value.items():
- str = str + f"| {key}: {value} | \n"
- else:
- str = str + f"| avg {avg_key}: {avg_value} | \n"
- str += "============================================================================================="
- return str
-
-
-def get_test_metrics(y_pred, y_true, img_names=None, logger=None):
- def get_video_metrics(image, pred, label):
- result_dict = {}
- new_label = []
- new_pred = []
- # print(image[0])
- # print(pred.shape)
- # print(label.shape)
- for item in np.transpose(np.stack((image, pred, label)), (1, 0)):
-
- s = item[0]
- if '\\' in s:
- parts = s.split('\\')
- else:
- parts = s.split('/')
- a = parts[-2]
- b = parts[-1]
-
- if a not in result_dict:
- result_dict[a] = []
-
- result_dict[a].append(item)
- image_arr = list(result_dict.values())
-
- for video in image_arr:
- pred_sum = 0
- label_sum = 0
- leng = 0
- for frame in video:
- pred_sum += float(frame[1])
- label_sum += int(frame[2])
- leng += 1
- new_pred.append(pred_sum / leng)
- new_label.append(int(label_sum / leng))
- fpr, tpr, thresholds = metrics.roc_curve(new_label, new_pred)
- v_auc = metrics.auc(fpr, tpr)
- fnr = 1 - tpr
- v_eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
- return v_auc, v_eer
-
-
- y_pred = y_pred.squeeze()
-
- # For UCF, where labels for different manipulations are not consistent.
- y_true[y_true >= 1] = 1
- # auc
- fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred, pos_label=1)
- auc = metrics.auc(fpr, tpr)
- # eer
- fnr = 1 - tpr
- eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
- # ap
- ap = metrics.average_precision_score(y_true, y_pred)
- # acc
- prediction_class = (y_pred > 0.5).astype(int)
- correct = (prediction_class == np.clip(y_true, a_min=0, a_max=1)).sum().item()
- acc = correct / len(prediction_class)
-
- return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap, 'pred': y_pred, 'label': y_true}
diff --git a/base_miner/DFB/networks/__init__.py b/base_miner/DFB/networks/__init__.py
deleted file mode 100644
index f9be2559..00000000
--- a/base_miner/DFB/networks/__init__.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import os
-import sys
-current_file_path = os.path.abspath(__file__)
-parent_dir = os.path.dirname(os.path.dirname(current_file_path))
-project_root_dir = os.path.dirname(parent_dir)
-sys.path.append(parent_dir)
-sys.path.append(project_root_dir)
-
-from metrics.registry import BACKBONE
-
-from .xception import Xception
\ No newline at end of file
diff --git a/base_miner/DFB/networks/xception.py b/base_miner/DFB/networks/xception.py
deleted file mode 100644
index 410345c5..00000000
--- a/base_miner/DFB/networks/xception.py
+++ /dev/null
@@ -1,285 +0,0 @@
-'''
-# author: Zhiyuan Yan
-# email: zhiyuanyan@link.cuhk.edu.cn
-# date: 2023-0706
-
-The code is mainly modified from GitHub link below:
-https://github.com/ondyari/FaceForensics/blob/master/classification/network/xception.py
-'''
-
-import os
-import argparse
-import logging
-
-import math
-import torch
-# import pretrainedmodels
-import torch.nn as nn
-import torch.nn.functional as F
-
-import torch.utils.model_zoo as model_zoo
-from torch.nn import init
-from typing import Union
-from metrics.registry import BACKBONE
-
-logger = logging.getLogger(__name__)
-
-
-
-class SeparableConv2d(nn.Module):
- def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
- super(SeparableConv2d, self).__init__()
-
- self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size,
- stride, padding, dilation, groups=in_channels, bias=bias)
- self.pointwise = nn.Conv2d(
- in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)
-
- def forward(self, x):
- x = self.conv1(x)
- x = self.pointwise(x)
- return x
-
-
-class Block(nn.Module):
- def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
- super(Block, self).__init__()
-
- if out_filters != in_filters or strides != 1:
- self.skip = nn.Conv2d(in_filters, out_filters,
- 1, stride=strides, bias=False)
- self.skipbn = nn.BatchNorm2d(out_filters)
- else:
- self.skip = None
-
- self.relu = nn.ReLU(inplace=True)
- rep = []
-
- filters = in_filters
- if grow_first: # whether the number of filters grows first
- rep.append(self.relu)
- rep.append(SeparableConv2d(in_filters, out_filters,
- 3, stride=1, padding=1, bias=False))
- rep.append(nn.BatchNorm2d(out_filters))
- filters = out_filters
-
- for i in range(reps-1):
- rep.append(self.relu)
- rep.append(SeparableConv2d(filters, filters,
- 3, stride=1, padding=1, bias=False))
- rep.append(nn.BatchNorm2d(filters))
-
- if not grow_first:
- rep.append(self.relu)
- rep.append(SeparableConv2d(in_filters, out_filters,
- 3, stride=1, padding=1, bias=False))
- rep.append(nn.BatchNorm2d(out_filters))
-
- if not start_with_relu:
- rep = rep[1:]
- else:
- rep[0] = nn.ReLU(inplace=False)
-
- if strides != 1:
- rep.append(nn.MaxPool2d(3, strides, 1))
- self.rep = nn.Sequential(*rep)
-
- def forward(self, inp):
- x = self.rep(inp)
-
- if self.skip is not None:
- skip = self.skip(inp)
- skip = self.skipbn(skip)
- else:
- skip = inp
-
- x += skip
- return x
-
-def add_gaussian_noise(ins, mean=0, stddev=0.2):
- noise = ins.data.new(ins.size()).normal_(mean, stddev)
- return ins + noise
-
-
-@BACKBONE.register_module(module_name="xception")
-class Xception(nn.Module):
- """
- Xception optimized for the ImageNet dataset, as specified in
- https://arxiv.org/pdf/1610.02357.pdf
- """
-
- def __init__(self, xception_config):
- """ Constructor
- Args:
- xception_config: configuration file with the dict format
- """
- super(Xception, self).__init__()
- self.num_classes = xception_config["num_classes"]
- self.mode = xception_config["mode"]
- inc = xception_config["inc"]
- dropout = xception_config["dropout"]
-
- # Entry flow
- self.conv1 = nn.Conv2d(inc, 32, 3, 2, 0, bias=False)
-
- self.bn1 = nn.BatchNorm2d(32)
- self.relu = nn.ReLU(inplace=True)
-
- self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
- self.bn2 = nn.BatchNorm2d(64)
- # do relu here
-
- self.block1 = Block(
- 64, 128, 2, 2, start_with_relu=False, grow_first=True)
- self.block2 = Block(
- 128, 256, 2, 2, start_with_relu=True, grow_first=True)
- self.block3 = Block(
- 256, 728, 2, 2, start_with_relu=True, grow_first=True)
-
- # middle flow
- self.block4 = Block(
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
- self.block5 = Block(
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
- self.block6 = Block(
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
- self.block7 = Block(
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
-
- self.block8 = Block(
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
- self.block9 = Block(
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
- self.block10 = Block(
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
- self.block11 = Block(
- 728, 728, 3, 1, start_with_relu=True, grow_first=True)
-
- # Exit flow
- self.block12 = Block(
- 728, 1024, 2, 2, start_with_relu=True, grow_first=False)
-
- self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
- self.bn3 = nn.BatchNorm2d(1536)
-
- # do relu here
- self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
- self.bn4 = nn.BatchNorm2d(2048)
- # used for iid
- final_channel = 2048
- if self.mode == 'adjust_channel_iid':
- final_channel = 512
- self.mode = 'adjust_channel'
- self.last_linear = nn.Linear(final_channel, self.num_classes)
- if dropout:
- self.last_linear = nn.Sequential(
- nn.Dropout(p=dropout),
- nn.Linear(final_channel, self.num_classes)
- )
-
- self.adjust_channel = nn.Sequential(
- nn.Conv2d(2048, 512, 1, 1),
- nn.BatchNorm2d(512),
- nn.ReLU(inplace=False),
- )
-
- def fea_part1_0(self, x):
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
-
- return x
-
- def fea_part1_1(self, x):
-
- x = self.conv2(x)
- x = self.bn2(x)
- x = self.relu(x)
-
- return x
-
- def fea_part1(self, x):
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
-
- x = self.conv2(x)
- x = self.bn2(x)
- x = self.relu(x)
-
- return x
-
- def fea_part2(self, x):
- x = self.block1(x)
- x = self.block2(x)
- x = self.block3(x)
-
- return x
-
- def fea_part3(self, x):
- if self.mode == "shallow_xception":
- return x
- else:
- x = self.block4(x)
- x = self.block5(x)
- x = self.block6(x)
- x = self.block7(x)
- return x
-
- def fea_part4(self, x):
- if self.mode == "shallow_xception":
- x = self.block12(x)
- else:
- x = self.block8(x)
- x = self.block9(x)
- x = self.block10(x)
- x = self.block11(x)
- x = self.block12(x)
- return x
-
- def fea_part5(self, x):
- x = self.conv3(x)
- x = self.bn3(x)
- x = self.relu(x)
-
- x = self.conv4(x)
- x = self.bn4(x)
-
- return x
-
- def features(self, input):
- x = self.fea_part1(input)
-
- x = self.fea_part2(x)
- x = self.fea_part3(x)
- x = self.fea_part4(x)
-
- x = self.fea_part5(x)
-
- if self.mode == 'adjust_channel':
- x = self.adjust_channel(x)
-
- return x
-
- def classifier(self, features,id_feat=None):
- # for iid
- if self.mode == 'adjust_channel':
- x = features
- else:
- x = self.relu(features)
-
- if len(x.shape) == 4:
- x = F.adaptive_avg_pool2d(x, (1, 1))
- x = x.view(x.size(0), -1)
- self.last_emb = x
- # for iid
-        if id_feat is not None:
- out = self.last_linear(x-id_feat)
- else:
- out = self.last_linear(x)
- return out
-
- def forward(self, input):
- x = self.features(input)
- out = self.classifier(x)
- return out, x
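SeparableConv2d above is a depthwise convolution (groups=in_channels) followed by a 1x1 pointwise convolution, which is where Xception gets its parameter savings. A quick comparison for a 3x3, 128-to-256-channel layer:

```python
import torch.nn as nn

standard = nn.Conv2d(128, 256, 3, padding=1, bias=False)
depthwise = nn.Conv2d(128, 128, 3, padding=1, groups=128, bias=False)
pointwise = nn.Conv2d(128, 256, 1, bias=False)

n_params = lambda m: sum(p.numel() for p in m.parameters())
print(n_params(standard))                         # 294912
print(n_params(depthwise) + n_params(pointwise))  # 33920
```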
diff --git a/base_miner/DFB/optimizor/LinearLR.py b/base_miner/DFB/optimizor/LinearLR.py
deleted file mode 100644
index 80bc70db..00000000
--- a/base_miner/DFB/optimizor/LinearLR.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import torch
-from torch.optim import SGD
-from torch.optim.lr_scheduler import _LRScheduler
-
-class LinearDecayLR(_LRScheduler):
- def __init__(self, optimizer, n_epoch, start_decay, last_epoch=-1):
- self.start_decay=start_decay
- self.n_epoch=n_epoch
- super(LinearDecayLR, self).__init__(optimizer, last_epoch)
-
- def get_lr(self):
- last_epoch = self.last_epoch
- n_epoch=self.n_epoch
- b_lr=self.base_lrs[0]
- start_decay=self.start_decay
- if last_epoch>start_decay:
- lr=b_lr-b_lr/(n_epoch-start_decay)*(last_epoch-start_decay)
- else:
- lr=b_lr
- return [lr]
\ No newline at end of file
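The schedule produced by LinearDecayLR above holds the base learning rate flat until start_decay, then decays it linearly toward zero at n_epoch. A quick sketch of the resulting values:

```python
def linear_decay(base_lr, n_epoch, start_decay, epoch):
    # Same arithmetic as LinearDecayLR.get_lr above
    if epoch <= start_decay:
        return base_lr
    return base_lr - base_lr / (n_epoch - start_decay) * (epoch - start_decay)

print([round(linear_decay(0.1, 10, 5, e), 3) for e in range(10)])
# [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.08, 0.06, 0.04, 0.02]
```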
diff --git a/base_miner/DFB/optimizor/SAM.py b/base_miner/DFB/optimizor/SAM.py
deleted file mode 100644
index 7b8d1dc5..00000000
--- a/base_miner/DFB/optimizor/SAM.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# borrowed from
-
-import torch
-import torch.nn as nn
-
-def disable_running_stats(model):
- def _disable(module):
- if isinstance(module, nn.BatchNorm2d):
- module.backup_momentum = module.momentum
- module.momentum = 0
-
- model.apply(_disable)
-
-def enable_running_stats(model):
- def _enable(module):
- if isinstance(module, nn.BatchNorm2d) and hasattr(module, "backup_momentum"):
- module.momentum = module.backup_momentum
-
- model.apply(_enable)
-
-class SAM(torch.optim.Optimizer):
- def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
- assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
-
- defaults = dict(rho=rho, **kwargs)
- super(SAM, self).__init__(params, defaults)
-
- self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
- self.param_groups = self.base_optimizer.param_groups
-
- @torch.no_grad()
- def first_step(self, zero_grad=False):
- grad_norm = self._grad_norm()
- for group in self.param_groups:
- scale = group["rho"] / (grad_norm + 1e-12)
-
- for p in group["params"]:
- if p.grad is None: continue
- e_w = p.grad * scale.to(p)
- p.add_(e_w) # climb to the local maximum "w + e(w)"
- self.state[p]["e_w"] = e_w
-
- if zero_grad: self.zero_grad()
-
- @torch.no_grad()
- def second_step(self, zero_grad=False):
- for group in self.param_groups:
- for p in group["params"]:
- if p.grad is None: continue
- p.sub_(self.state[p]["e_w"]) # get back to "w" from "w + e(w)"
-
- self.base_optimizer.step() # do the actual "sharpness-aware" update
-
- if zero_grad: self.zero_grad()
-
- @torch.no_grad()
- def step(self, closure=None):
- assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
- closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass
-
- self.first_step(zero_grad=True)
- closure()
- self.second_step()
-
- def _grad_norm(self):
- shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism
- norm = torch.norm(
- torch.stack([
- p.grad.norm(p=2).to(shared_device)
- for group in self.param_groups for p in group["params"]
- if p.grad is not None
- ]),
- p=2
- )
- return norm
\ No newline at end of file
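As the comments in SAM.step note, the optimizer needs a closure that re-runs the full forward-backward pass on the perturbed weights. A minimal training-step sketch with a dummy model and batch (train_detector.py below constructs SAM the same way, wrapping optim.SGD):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2)
criterion = nn.CrossEntropyLoss()
inputs, targets = torch.randn(4, 8), torch.tensor([0, 1, 0, 1])

optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.01, momentum=0.9)

def closure():
    # Re-evaluate the loss at the perturbed weights "w + e(w)"
    loss = criterion(model(inputs), targets)
    loss.backward()
    return loss

criterion(model(inputs), targets).backward()  # gradients for first_step
optimizer.step(closure)                       # perturb, re-backprop, update
optimizer.zero_grad()
```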
diff --git a/base_miner/DFB/train_detector.py b/base_miner/DFB/train_detector.py
deleted file mode 100644
index b5abf663..00000000
--- a/base_miner/DFB/train_detector.py
+++ /dev/null
@@ -1,368 +0,0 @@
-# This script was adapted from the DeepfakeBench training code,
-# originally authored by Zhiyuan Yan (zhiyuanyan@link.cuhk.edu.cn)
-
-# Original: https://github.com/SCLBD/DeepfakeBench/blob/main/training/train.py
-
-# BitMind's modifications include adding a testing phase, changing the
-# data load/split pipeline to work with subnet 34's image augmentations
-# and datasets from BitMind HuggingFace repositories, quality of life CLI args,
-# logging changes, etc.
-
-import os
-import sys
-import argparse
-from os.path import join
-import random
-import datetime
-import time
-import yaml
-from tqdm import tqdm
-import numpy as np
-from datetime import timedelta
-from copy import deepcopy
-from PIL import Image as pil_image
-from pathlib import Path
-import gc
-
-import torch
-import torch.nn as nn
-import torch.nn.parallel
-import torch.backends.cudnn as cudnn
-import torch.utils.data
-import torch.optim as optim
-from torch.utils.data.distributed import DistributedSampler
-import torch.distributed as dist
-from torch.utils.data import DataLoader
-
-from base_miner.DFB.optimizor.SAM import SAM
-from base_miner.DFB.optimizor.LinearLR import LinearDecayLR
-from base_miner.DFB.config.helpers import save_config
-from base_miner.DFB.trainer.trainer import Trainer
-from base_miner.DFB.detectors import DETECTOR
-from base_miner.DFB.metrics.utils import parse_metric_for_print
-from base_miner.DFB.logger import create_logger, RankFilter
-
-from huggingface_hub import hf_hub_download
-
-# BitMind imports (not from original Deepfake Bench repo)
-from base_miner.datasets.util import load_and_split_datasets, create_real_fake_datasets
-from base_miner.config import VIDEO_DATASETS, IMAGE_DATASETS, FACE_IMAGE_DATASETS
-from bitmind.utils.image_transforms import (
- get_base_transforms,
- get_random_augmentations,
- get_ucf_base_transforms,
- get_tall_base_transforms
-)
-from base_miner.DFB.config.constants import (
- CONFIG_PATHS,
- WEIGHTS_DIR,
- HF_REPOS
-)
-
-TRANSFORM_FNS = {
- 'UCF': get_ucf_base_transforms,
- 'TALL': get_tall_base_transforms
-}
-
-
-parser = argparse.ArgumentParser(description='Process some paths.')
-parser.add_argument('--detector', type=str, choices=['UCF', 'TALL'], required=True, help='Detector name')
-parser.add_argument('--modality', type=str, default='image', choices=['image', 'video'])
-parser.add_argument('--faces_only', dest='faces_only', action='store_true', default=False)
-parser.add_argument('--no-save_ckpt', dest='save_ckpt', action='store_false', default=True)
-parser.add_argument('--no-save_feat', dest='save_feat', action='store_false', default=True)
-parser.add_argument("--ddp", action='store_true', default=False)
-parser.add_argument('--device', type=str, default='cuda',
- help='Specify whether to use CPU or GPU. Defaults to GPU if available.')
-parser.add_argument('--gpu_id', type=int, default=0, help='Specify the GPU ID to use if using GPU. Defaults to 0.')
-parser.add_argument('--workers', type=int, default=os.cpu_count() - 1,
- help='number of workers for data loading')
-parser.add_argument('--epochs', type=int, default=None, help='number of training epochs')
-args = parser.parse_args()
-
-
-def init_seed(config):
- if config['manualSeed'] is None:
- config['manualSeed'] = random.randint(1, 10000)
- random.seed(config['manualSeed'])
- if config['cuda']:
- torch.manual_seed(config['manualSeed'])
- torch.cuda.manual_seed_all(config['manualSeed'])
-
-
-def prepare_datasets(config, logger):
- start_time = log_start_time(logger, "Loading and splitting individual datasets")
-
- fake_datasets = load_and_split_datasets(
- config['dataset_meta']['fake'], modality=config['modality'], split_transforms=config['split_transforms'])
- real_datasets = load_and_split_datasets(
- config['dataset_meta']['real'], modality=config['modality'], split_transforms=config['split_transforms'])
-
- log_finish_time(logger, "Loading and splitting individual datasets", start_time)
-
- start_time = log_start_time(logger, "Creating real fake dataset splits")
- train_dataset, val_dataset, test_dataset, source_label_mapping = create_real_fake_datasets(
- real_datasets,
- fake_datasets,
- source_labels=True, # TODO UCF Only
- group_sources_by_name=True)
-
- log_finish_time(logger, "Creating real fake dataset splits", start_time)
-
- train_loader = torch.utils.data.DataLoader(
- train_dataset,
- batch_size=config['train_batchSize'],
- shuffle=True,
- num_workers=config['workers'],
- drop_last=True,
- collate_fn=train_dataset.collate_fn)
-
- val_loader = torch.utils.data.DataLoader(
- val_dataset,
- batch_size=config['train_batchSize'],
- shuffle=True,
- num_workers=config['workers'],
- drop_last=True,
- collate_fn=val_dataset.collate_fn)
-
- test_loader = torch.utils.data.DataLoader(
- test_dataset,
- batch_size=config['train_batchSize'],
- shuffle=True,
- num_workers=config['workers'],
- drop_last=True,
- collate_fn=train_dataset.collate_fn)
-
- print(f"Train size: {len(train_loader.dataset)}")
- print(f"Validation size: {len(val_loader.dataset)}")
- print(f"Test size: {len(test_loader.dataset)}")
-
- return train_loader, val_loader, test_loader, source_label_mapping
-
-
-def choose_optimizer(model, config):
- opt_name = config['optimizer']['type']
- if opt_name == 'sgd':
- optimizer = optim.SGD(
- params=model.parameters(),
- lr=config['optimizer'][opt_name]['lr'],
- momentum=config['optimizer'][opt_name]['momentum'],
- weight_decay=config['optimizer'][opt_name]['weight_decay']
- )
- elif opt_name == 'adam':
- optimizer = optim.Adam(
- params=model.parameters(),
- lr=config['optimizer'][opt_name]['lr'],
- weight_decay=config['optimizer'][opt_name]['weight_decay'],
- betas=(config['optimizer'][opt_name]['beta1'], config['optimizer'][opt_name]['beta2']),
- eps=config['optimizer'][opt_name]['eps'],
- amsgrad=config['optimizer'][opt_name]['amsgrad'],
- )
- elif opt_name == 'sam':
- optimizer = SAM(
- model.parameters(),
- optim.SGD,
- lr=config['optimizer'][opt_name]['lr'],
- momentum=config['optimizer'][opt_name]['momentum'],
- )
- else:
- raise NotImplementedError('Optimizer {} is not implemented'.format(config['optimizer']))
- return optimizer
-
-
-def choose_scheduler(config, optimizer):
- if config['lr_scheduler'] is None:
- scheduler = None
- elif config['lr_scheduler'] == 'step':
- scheduler = optim.lr_scheduler.StepLR(
- optimizer,
- step_size=config['lr_step'],
- gamma=config['lr_gamma'],
- )
- elif config['lr_scheduler'] == 'cosine':
- scheduler = optim.lr_scheduler.CosineAnnealingLR(
- optimizer,
- T_max=config['lr_T_max'],
- eta_min=config['lr_eta_min'],
- )
- elif config['lr_scheduler'] == 'linear':
- scheduler = LinearDecayLR(
- optimizer,
- config['nEpochs'],
- int(config['nEpochs']/4),
- )
- else:
- raise NotImplementedError('Scheduler {} is not implemented'.format(config['lr_scheduler']))
- return scheduler
-
-
-def choose_metric(config):
- metric_scoring = config['metric_scoring']
- if metric_scoring not in ['eer', 'auc', 'acc', 'ap']:
- raise NotImplementedError('metric {} is not implemented'.format(metric_scoring))
- return metric_scoring
-
-
-def log_start_time(logger, process_name):
- """Log the start time of a process."""
- start_time = time.time()
- logger.info(f"{process_name} Start Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}")
- return start_time
-
-
-def log_finish_time(logger, process_name, start_time):
- """Log the finish time and elapsed time of a process."""
- finish_time = time.time()
- elapsed_time = finish_time - start_time
-
- # Convert elapsed time into hours, minutes, and seconds
- hours, rem = divmod(elapsed_time, 3600)
- minutes, seconds = divmod(rem, 60)
-
- # Log the finish time and elapsed time
- logger.info(f"{process_name} Finish Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(finish_time))}")
- logger.info(f"{process_name} Elapsed Time: {int(hours)} hours, {int(minutes)} minutes, {seconds:.2f} seconds")
-
-
-def main():
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
-
- gc.collect()
-
- detector_config_path = CONFIG_PATHS[args.detector]
-
- # parse options and load config
- with open(detector_config_path, 'r') as f:
- config = yaml.safe_load(f)
-
- config['log_dir'] = os.getcwd()
- config['device'] = args.device
- config['modality'] = args.modality
- config['workers'] = args.workers
- config['gpu_id'] = args.gpu_id
- if args.epochs:
- config['nEpochs'] = args.epochs
-
- tforms = TRANSFORM_FNS.get(args.detector, None)((256, 256))
- config['split_transforms'] = {
- 'train': tforms,
- 'validation': tforms,
- 'test': tforms
- }
-
- if config['modality'] == 'video':
- config['dataset_meta'] = VIDEO_DATASETS
- elif config['modality'] == 'image':
- if args.faces_only:
- config['dataset_meta'] = FACE_IMAGE_DATASETS
- else:
- config['dataset_meta'] = IMAGE_DATASETS
-
- dataset_names = [item["path"] for datasets in config['dataset_meta'].values() for item in datasets]
- config['train_dataset'] = dataset_names
- config['save_ckpt'] = args.save_ckpt
- config['save_feat'] = args.save_feat
-
- # create logger
- timenow=datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
- outputs_dir = os.path.join(config['log_dir'], 'logs', config['model_name'] + '_' + timenow)
- config['log_dir'] = outputs_dir
-
- os.makedirs(outputs_dir, exist_ok=True)
- logger = create_logger(os.path.join(outputs_dir, 'training.log'))
- logger.info('Save log to {}'.format(outputs_dir))
-
- config['ddp']= args.ddp
-
- # prepare the data loaders
- train_loader, val_loader, test_loader, source_label_mapping = prepare_datasets(config, logger)
- config['specific_task_number'] = len(set(source_label_mapping.values()))
-
- # init seed
- init_seed(config)
-
- # set cudnn benchmark if needed
- if config['cudnn']:
- cudnn.benchmark = True
- if config['ddp']:
- # dist.init_process_group(backend='gloo')
- dist.init_process_group(
- backend='nccl',
- timeout=timedelta(minutes=30)
- )
- logger.addFilter(RankFilter(0))
-
- # download weights if huggingface repo provided.
- # Note: TALL currently skips this and downloads from github
- pretrained_config = config.get('pretrained', {})
- if not isinstance(pretrained_config, str):
- hf_repo = pretrained_config.get('hf_repo')
- weights_filename = pretrained_config.get('filename')
- if hf_repo and weights_filename:
- local_path = Path(WEIGHTS_DIR) / weights_filename
- if not local_path.exists():
- model_path = hf_hub_download(
- repo_id=hf_repo,
- filename=weights_filename,
- local_dir=WEIGHTS_DIR
- )
- logger.info(f"Downloaded {hf_repo}/{weights_filename} to {model_path}")
- else:
- model_path = local_path
- logger.info(f"{model_path} exists, skipping download")
- config['pretrained']['local_path'] = str(model_path)
- else:
- logger.info("Pretrain config is a url, falling back to detector-specific download")
-
- # prepare model and trainer
- model_class = DETECTOR[config['model_name']]
- model = model_class(config).to(config['device'])
-
- optimizer = choose_optimizer(model, config)
- scheduler = choose_scheduler(config, optimizer)
- metric_scoring = choose_metric(config)
- trainer = Trainer(config, model, config['device'], optimizer, scheduler, logger, metric_scoring)
-
- logger.info("--------------- Configuration ---------------")
- params_string = "Parameters: \n"
- for key, value in config.items():
- params_string += "{}: {}".format(key, value) + "\n"
- logger.info(params_string)
-
- # save training configs
- save_config(config, outputs_dir)
-
- # start training
- start_time = log_start_time(logger, "Training")
- for epoch in range(config['start_epoch'], config['nEpochs'] + 1):
- trainer.model.epoch = epoch
- best_metric = trainer.train_epoch(
- epoch,
- train_data_loader=train_loader,
- validation_data_loaders={'val':val_loader}
- )
- if best_metric is not None:
- logger.info(f"===> Epoch[{epoch}] end with validation {metric_scoring}: {parse_metric_for_print(best_metric)}!")
-
- logger.info("Stop Training on best Validation metric {}".format(parse_metric_for_print(best_metric)))
- log_finish_time(logger, "Training", start_time)
-
- # test
- start_time = log_start_time(logger, "Test")
- trainer.eval(eval_data_loaders={'test':test_loader}, eval_stage="test")
- log_finish_time(logger, "Test", start_time)
-
- if scheduler is not None:
- scheduler.step()
-
- # close the tensorboard writers
- for writer in trainer.writers.values():
- writer.close()
-
- torch.cuda.empty_cache()
- gc.collect()
-
-
-if __name__ == '__main__':
- main()
diff --git a/base_miner/DFB/trainer/trainer.py b/base_miner/DFB/trainer/trainer.py
deleted file mode 100644
index d4287c84..00000000
--- a/base_miner/DFB/trainer/trainer.py
+++ /dev/null
@@ -1,439 +0,0 @@
-# This script was adapted from the DeepfakeBench training code,
-# originally authored by Zhiyuan Yan (zhiyuanyan@link.cuhk.edu.cn)
-
-# Original: https://github.com/SCLBD/DeepfakeBench/blob/main/training/train.py
-
-import os
-import sys
-current_file_path = os.path.abspath(__file__)
-parent_dir = os.path.dirname(os.path.dirname(current_file_path))
-project_root_dir = os.path.dirname(parent_dir)
-sys.path.append(parent_dir)
-sys.path.append(project_root_dir)
-
-import pickle
-import datetime
-import logging
-import numpy as np
-from copy import deepcopy
-from collections import defaultdict
-from tqdm import tqdm
-import time
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-from torch.nn import DataParallel
-from torch.utils.tensorboard import SummaryWriter
-from metrics.base_metrics_class import Recorder
-from torch.optim.swa_utils import AveragedModel, SWALR
-from torch import distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
-from sklearn import metrics
-from metrics.utils import get_test_metrics
-
-
-class Trainer(object):
- def __init__(
- self,
- config,
- model,
- device,
- optimizer,
- scheduler,
- logger,
- metric_scoring='auc',
- swa_model=None
- ):
- # check if all the necessary components are implemented
- if config is None or model is None or optimizer is None or logger is None:
- raise ValueError("config, model, optimizier, logger, and tensorboard writer must be implemented")
-
- self.config = config
- self.model = model
- self.device = device
- self.optimizer = optimizer
- self.scheduler = scheduler
- self.swa_model = swa_model
- self.writers = {} # dict to maintain different tensorboard writers for each dataset and metric
- self.logger = logger
- self.metric_scoring = metric_scoring
- # maintain the best metric of all epochs
- self.best_metrics_all_time = defaultdict(
- lambda: defaultdict(lambda: float('-inf')
- if self.metric_scoring != 'eer' else float('inf'))
- )
- self.speed_up()
-
- # create directory path
- self.log_dir = self.config['log_dir']
- print("Making dir ", self.log_dir)
- os.makedirs(self.log_dir, exist_ok=True)
-
- def get_writer(self, phase, dataset_key, metric_key):
- phase = phase.split('/')[-1]
- dataset_key = dataset_key.split('/')[-1]
- metric_key = metric_key.split('/')[-1]
- writer_key = f"{phase}-{dataset_key}-{metric_key}"
- if writer_key not in self.writers:
- # update directory path
- writer_path = os.path.join(
- self.log_dir,
- phase,
- dataset_key,
- metric_key,
- "metric_board"
- )
- os.makedirs(writer_path, exist_ok=True)
- # update writers dictionary
- self.writers[writer_key] = SummaryWriter(writer_path)
- return self.writers[writer_key]
-
- def speed_up(self):
- self.model.device = self.device
- if 'cuda' in str(self.device) and self.config['ddp'] == True and torch.cuda.is_available():
- num_gpus = torch.cuda.device_count()
- print(f'available gpus: {num_gpus}')
- self.model = DDP(self.model, device_ids=[self.config['gpu_id']],
- find_unused_parameters=True, output_device=self.config['gpu_id'])
-
- def setTrain(self):
- self.model.train()
- self.train = True
-
- def setEval(self):
- self.model.eval()
- self.train = False
-
- def load_ckpt(self, model_path):
- if os.path.isfile(model_path):
- saved = torch.load(model_path, map_location='cpu')
- suffix = model_path.split('.')[-1]
- if suffix == 'p':
- self.model.load_state_dict(saved.state_dict())
- else:
- self.model.load_state_dict(saved)
- self.logger.info('Model found in {}'.format(model_path))
- else:
- raise NotImplementedError(
- "=> no model found at '{}'".format(model_path))
-
- def save_ckpt(self, phase, dataset_key,ckpt_info=None):
- save_dir = self.log_dir
- os.makedirs(save_dir, exist_ok=True)
- ckpt_name = f"ckpt_best.pth"
- save_path = os.path.join(save_dir, ckpt_name)
- if self.config['ddp'] == True:
- torch.save(self.model.state_dict(), save_path)
- else:
- if 'svdd' in self.config['model_name']:
- torch.save({'R': self.model.R,
- 'c': self.model.c,
- 'state_dict': self.model.state_dict(),}, save_path)
- else:
- torch.save(self.model.state_dict(), save_path)
- self.logger.info(f"Checkpoint saved to {save_path}, current ckpt is {ckpt_info}")
-
- def save_swa_ckpt(self):
- save_dir = self.log_dir
- os.makedirs(save_dir, exist_ok=True)
- ckpt_name = f"swa.pth"
- save_path = os.path.join(save_dir, ckpt_name)
- torch.save(self.swa_model.state_dict(), save_path)
- self.logger.info(f"SWA Checkpoint saved to {save_path}")
-
- def save_feat(self, phase, fea, dataset_key):
- save_dir = os.path.join(self.log_dir, phase, dataset_key)
- os.makedirs(save_dir, exist_ok=True)
- features = fea
- feat_name = f"feat_best.npy"
- save_path = os.path.join(save_dir, feat_name)
- np.save(save_path, features)
- self.logger.info(f"Feature saved to {save_path}")
-
- def save_data_dict(self, phase, data_dict, dataset_key):
- save_dir = os.path.join(self.log_dir, phase, dataset_key)
- os.makedirs(save_dir, exist_ok=True)
- file_path = os.path.join(save_dir, f'data_dict_{phase}.pickle')
- with open(file_path, 'wb') as file:
- pickle.dump(data_dict, file)
- self.logger.info(f"data_dict saved to {file_path}")
-
- def save_metrics(self, phase, metric_one_dataset, dataset_key):
- save_dir = os.path.join(self.log_dir, phase, dataset_key)
- os.makedirs(save_dir, exist_ok=True)
- file_path = os.path.join(save_dir, 'metric_dict_best.pickle')
- with open(file_path, 'wb') as file:
- pickle.dump(metric_one_dataset, file)
- self.logger.info(f"Metrics saved to {file_path}")
-
- def train_step(self,data_dict):
- if self.config['optimizer']['type']=='sam':
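- # SAM (sharpness-aware minimization) takes two forward/backward passes per batch:
- # first_step perturbs the weights toward higher loss, second_step restores them and applies the actual update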
- for i in range(2):
- predictions = self.model(data_dict)
- losses = self.model.get_losses(data_dict, predictions)
- if i == 0:
- pred_first = predictions
- losses_first = losses
- self.optimizer.zero_grad()
- losses['overall'].backward()
- if i == 0:
- self.optimizer.first_step(zero_grad=True)
- else:
- self.optimizer.second_step(zero_grad=True)
- return losses_first, pred_first
- else:
- predictions = self.model(data_dict)
- if type(self.model) is DDP:
- losses = self.model.module.get_losses(data_dict, predictions)
- else:
- losses = self.model.get_losses(data_dict, predictions)
- self.optimizer.zero_grad()
- losses['overall'].backward()
- self.optimizer.step()
-
- return losses,predictions
-
- def train_epoch(
- self,
- epoch,
- train_data_loader,
- validation_data_loaders=None
- ):
-
- self.logger.info("===> Epoch[{}] start!".format(epoch))
- if epoch>=1:
- times_per_epoch = 2
- else:
- times_per_epoch = 1
-
-
- #times_per_epoch=4
- validation_step = len(train_data_loader) // times_per_epoch # validate times_per_epoch times per epoch
- step_cnt = epoch * len(train_data_loader)
-
- # define training recorder
- train_recorder_loss = defaultdict(Recorder)
- train_recorder_metric = defaultdict(Recorder)
-
- for iteration, data_dict in tqdm(enumerate(train_data_loader),total=len(train_data_loader)):
- self.setTrain()
- # if using GPU, move data to GPU
- if 'cuda' in str(self.device):
- for key in data_dict.keys():
- if data_dict[key] is not None and key!='name':
- data_dict[key] = data_dict[key].cuda()
-
- losses, predictions=self.train_step(data_dict)
- # update learning rate
-
- if self.config.get('SWA', False) and epoch>self.config['swa_start']:
- self.swa_model.update_parameters(self.model)
-
- # compute training metric for each batch data
- if type(self.model) is DDP:
- batch_metrics = self.model.module.get_train_metrics(data_dict, predictions)
- else:
- batch_metrics = self.model.get_train_metrics(data_dict, predictions)
-
- # store data by recorder
- ## store metric
- for name, value in batch_metrics.items():
- train_recorder_metric[name].update(value)
- ## store loss
- for name, value in losses.items():
- train_recorder_loss[name].update(value)
-
- # run tensorboard to visualize the training process
- if iteration % 300 == 0 and self.config['gpu_id']==0:
- if self.config.get('SWA', False) and (epoch>self.config['swa_start'] or self.config['dry_run']):
- self.scheduler.step()
- # info for loss
- loss_str = f"Iter: {step_cnt} "
- for k, v in train_recorder_loss.items():
- v_avg = v.average()
- if v_avg == None:
- loss_str += f"training-loss, {k}: not calculated"
- continue
- loss_str += f"training-loss, {k}: {v_avg} "
- # tensorboard-1. loss
- processed_train_dataset = [dataset.split('/')[-1] for dataset in self.config['train_dataset']]
- processed_train_dataset = ','.join(processed_train_dataset)
- writer = self.get_writer('train', processed_train_dataset, k)
- writer.add_scalar(f'train_loss/{k}', v_avg, global_step=step_cnt)
- self.logger.info(loss_str)
- # info for metric
- metric_str = f"Iter: {step_cnt} "
- for k, v in train_recorder_metric.items():
- v_avg = v.average()
- if v_avg == None:
- metric_str += f"training-metric, {k}: not calculated "
- continue
- metric_str += f"training-metric, {k}: {v_avg} "
- # tensorboard-2. metric
- processed_train_dataset = [dataset.split('/')[-1] for dataset in self.config['train_dataset']]
- processed_train_dataset = ','.join(processed_train_dataset)
- writer = self.get_writer('train', processed_train_dataset, k)
- writer.add_scalar(f'train_metric/{k}', v_avg, global_step=step_cnt)
- self.logger.info(metric_str)
-
- # clear recorder.
- # Note we only consider the current 300 samples for computing batch-level loss/metric
- for name, recorder in train_recorder_loss.items(): # clear loss recorder
- recorder.clear()
- for name, recorder in train_recorder_metric.items(): # clear metric recorder
- recorder.clear()
-
- # run validation
- if (step_cnt+1) % validation_step == 0:
- if validation_data_loaders is not None and ((not self.config['ddp']) or (self.config['ddp'] and dist.get_rank() == 0)):
- self.logger.info("===> Validation start!")
- validation_best_metric = self.eval(
- eval_data_loaders=validation_data_loaders,
- eval_stage="validation",
- step=step_cnt,
- epoch=epoch,
- iteration=iteration
- )
- else:
- validation_best_metric = None
-
- step_cnt += 1
-
- for key in data_dict.keys():
- if data_dict[key]!=None and key!='name':
- data_dict[key]=data_dict[key].cpu()
- return validation_best_metric
-
- def get_respect_acc(self,prob,label):
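- # assumes the batch is ordered with all real (label 0) samples first, followed by the fake (label 1) samples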
- pred = np.where(prob > 0.5, 1, 0)
- judge = (pred == label)
- zero_num = len(label) - np.count_nonzero(label)
- acc_fake = np.count_nonzero(judge[zero_num:]) / len(judge[zero_num:])
- acc_real = np.count_nonzero(judge[:zero_num]) / len(judge[:zero_num])
- return acc_real,acc_fake
-
- def eval_one_dataset(self, data_loader):
- # define eval recorder
- eval_recorder_loss = defaultdict(Recorder)
- prediction_lists = []
- feature_lists=[]
- label_lists = []
- for i, data_dict in tqdm(enumerate(data_loader),total=len(data_loader)):
- # get data
- if 'label_spe' in data_dict:
- data_dict.pop('label_spe') # remove the specific label
- data_dict['label'] = torch.where(data_dict['label']!=0, 1, 0) # fix the label to 0 and 1 only
- # if using GPU, move data to GPU
- if 'cuda' in str(self.device):
- for key in data_dict.keys():
- if data_dict[key] is not None:
- data_dict[key] = data_dict[key].cuda()
- # model forward without considering gradient computation
- predictions = self.inference(data_dict) #dict with keys cls, feat
- label_lists += list(data_dict['label'].cpu().detach().numpy())
- # Get the predicted class for each sample in the batch
- _, predicted_classes = torch.max(predictions['cls'], dim=1)
- # Convert the predicted class indices to a list and add to prediction_lists
- prediction_lists += predicted_classes.cpu().detach().numpy().tolist()
- feature_lists += list(predictions['feat'].cpu().detach().numpy())
- if type(self.model) is not AveragedModel:
- # compute all losses for each batch data
- if type(self.model) is DDP:
- losses = self.model.module.get_losses(data_dict, predictions)
- else:
- losses = self.model.get_losses(data_dict, predictions)
-
- # store data by recorder
- for name, value in losses.items():
- eval_recorder_loss[name].update(value)
- return eval_recorder_loss, np.array(prediction_lists), np.array(label_lists),np.array(feature_lists)
-
- def save_best(self,epoch,iteration,step,losses_one_dataset_recorder,key,metric_one_dataset,eval_stage):
- best_metric = self.best_metrics_all_time[key].get(self.metric_scoring,
- float('-inf') if self.metric_scoring != 'eer' else float(
- 'inf'))
- # Check if the current score is an improvement
- improved = (metric_one_dataset[self.metric_scoring] > best_metric) if self.metric_scoring != 'eer' else (
- metric_one_dataset[self.metric_scoring] < best_metric)
- if improved:
- # Update the best metric
- self.best_metrics_all_time[key][self.metric_scoring] = metric_one_dataset[self.metric_scoring]
- if key == 'avg':
- self.best_metrics_all_time[key]['dataset_dict'] = metric_one_dataset['dataset_dict']
- # Save checkpoint, feature, and metrics if specified in config
- if eval_stage=='validation' and self.config['save_ckpt']:
- self.save_ckpt(eval_stage, key, f"{epoch}+{iteration}")
- self.save_metrics(eval_stage, metric_one_dataset, key)
- if losses_one_dataset_recorder is not None:
- # info for each dataset
- loss_str = f"dataset: {key} step: {step} "
- for k, v in losses_one_dataset_recorder.items():
- writer = self.get_writer(eval_stage, key, k)
- v_avg = v.average()
- if v_avg == None:
- print(f'{k} is not calculated')
- continue
- # tensorboard-1. loss
- writer.add_scalar(f'{eval_stage}_losses/{k}', v_avg, global_step=step)
- loss_str += f"{eval_stage}-loss, {k}: {v_avg} "
- self.logger.info(loss_str)
- # tqdm.write(loss_str)
- metric_str = f"dataset: {key} step: {step} "
- for k, v in metric_one_dataset.items():
- if k == 'pred' or k == 'label' or k=='dataset_dict':
- continue
- metric_str += f"{eval_stage}-metric, {k}: {v} "
- # tensorboard-2. metric
- writer = self.get_writer(eval_stage, key, k)
- writer.add_scalar(f'{eval_stage}_metrics/{k}', v, global_step=step)
- if 'pred' in metric_one_dataset:
- acc_real, acc_fake = self.get_respect_acc(metric_one_dataset['pred'], metric_one_dataset['label'])
- metric_str += f'{eval_stage}-metric, acc_real:{acc_real}; acc_fake:{acc_fake}'
- writer.add_scalar(f'{eval_stage}_metrics/acc_real', acc_real, global_step=step)
- writer.add_scalar(f'{eval_stage}_metrics/acc_fake', acc_fake, global_step=step)
- self.logger.info(metric_str)
-
- def eval(self, eval_data_loaders, eval_stage, step=None, epoch=None, iteration=None):
- # set model to eval mode
- self.setEval()
-
- # define eval recorder
- losses_all_datasets = {}
- metrics_all_datasets = {}
- best_metrics_per_dataset = defaultdict(dict) # best metric for each dataset, for each metric
- avg_metric = {'acc': 0, 'auc': 0, 'eer': 0, 'ap': 0,'dataset_dict':{}} #'video_auc': 0
- keys = eval_data_loaders.keys()
- for key in keys:
- # compute loss for each dataset
- losses_one_dataset_recorder, predictions_nps, label_nps, feature_nps = self.eval_one_dataset(eval_data_loaders[key])
- losses_all_datasets[key] = losses_one_dataset_recorder
- metric_one_dataset=get_test_metrics(y_pred=predictions_nps,y_true=label_nps, logger=self.logger)
-
- for metric_name, value in metric_one_dataset.items():
- if metric_name in avg_metric:
- avg_metric[metric_name]+=value
- avg_metric['dataset_dict'][key] = metric_one_dataset[self.metric_scoring]
- if type(self.model) is AveragedModel:
- metric_str = f"Iter Final for SWA: "
- for k, v in metric_one_dataset.items():
- metric_str += f"{eval_stage}-metric, {k}: {v} "
- self.logger.info(metric_str)
- continue
- self.save_best(epoch,iteration,step,losses_one_dataset_recorder,key,metric_one_dataset,eval_stage)
-
- if len(keys)>0 and self.config.get('save_avg',False):
- # calculate avg value
- for key in avg_metric:
- if key != 'dataset_dict':
- avg_metric[key] /= len(keys)
- self.save_best(epoch, iteration, step, None, 'avg', avg_metric, eval_stage)
-
- self.logger.info(f'===> {eval_stage} Done!')
- return self.best_metrics_all_time # return all types of mean metrics for determining the best ckpt
-
-
- @torch.no_grad()
- def inference(self, data_dict):
- predictions = self.model(data_dict, inference=True)
- return predictions
\ No newline at end of file
diff --git a/base_miner/NPR/NPR.png b/base_miner/NPR/NPR.png
deleted file mode 100644
index c64a1dcd..00000000
Binary files a/base_miner/NPR/NPR.png and /dev/null differ
diff --git a/base_miner/NPR/README.md b/base_miner/NPR/README.md
deleted file mode 100644
index e89e9025..00000000
--- a/base_miner/NPR/README.md
+++ /dev/null
@@ -1,90 +0,0 @@
-# Rethinking the Up-Sampling Operations in CNN-based Generative Network for Generalizable Deepfake Detection
-
-
-
-Our base miner code is taken from the [NPR-DeepfakeDetection repository](https://github.com/chuangchuangtan/NPR-DeepfakeDetection). Huge thank you to the authors for their work on their CVPR paper and this codebase!
--- Bitmind Devs
-
- Beijing Jiaotong University, YanShan University, A*Star
-
-Reference github repository for the paper [Rethinking the Up-Sampling Operations in CNN-based Generative Network for Generalizable Deepfake Detection](https://arxiv.org/abs/2312.10461).
-```
-@misc{tan2023rethinking,
- title={Rethinking the Up-Sampling Operations in CNN-based Generative Network for Generalizable Deepfake Detection},
- author={Chuangchuang Tan and Huan Liu and Yao Zhao and Shikui Wei and Guanghua Gu and Ping Liu and Yunchao Wei},
- year={2023},
- eprint={2312.10461},
- archivePrefix={arXiv},
- primaryClass={cs.CV}
-}
-```
-
-## News 🆕
-- `2024/02`: NPR is accepted by CVPR 2024! Congratulations and thanks to all my co-authors!
-
-
-
-## Environment setup
-**Classification environment:**
-We recommend installing the required packages by running the command:
-```sh
-pip install -r requirements.txt
-```
-In order to ensure the reproducibility of the results, we provide the following suggestions:
-- Docker image: nvcr.io/nvidia/tensorflow:21.02-tf1-py3
-- Conda environment: [./pytorch18/bin/python](https://drive.google.com/file/d/16MK7KnPebBZx5yeN6jqJ49k7VWbEYQPr/view)
-- Random seed during testing period: [Random seed](https://github.com/chuangchuangtan/NPR-DeepfakeDetection/blob/b4e1bfa59ec58542ab5b1e78a3b75b54df67f3b8/test.py#L14)
-
-## Getting the data
-Download dataset from [CNNDetection CVPR2020](https://github.com/peterwang512/CNNDetection), [UniversalFakeDetect CVPR2023](https://github.com/Yuheng-Li/UniversalFakeDetect) ([googledrive](https://drive.google.com/drive/folders/1nkCXClC7kFM01_fqmLrVNtnOYEFPtWO-?usp=drive_link)), [DIRE 2023ICCV](https://github.com/ZhendongWang6/DIRE) ([googledrive](https://drive.google.com/drive/folders/1jZE4hg6SxRvKaPYO_yyMeJN_DOcqGMEf?usp=sharing)), [GANGen-Detection](https://github.com/chuangchuangtan/GANGen-Detection) ([googledrive](https://drive.google.com/drive/folders/11E0Knf9J1qlv2UuTnJSOFUjIIi90czSj?usp=sharing)), Diffusion1kStep [googledrive](https://drive.google.com/drive/folders/14f0vApTLiukiPvIHukHDzLujrvJpDpRq?usp=sharing).
-```
-pip install gdown==4.7.1
-
-chmod 777 ./download_dataset.sh
-
-./download_dataset.sh
-```
-
-## Training the model
-```sh
-CUDA_VISIBLE_DEVICES=0 python train.py --name 4class-resnet-car-cat-chair-horse --dataroot {CNNDetection-Path} --classes car,cat,chair,horse --batch_size 32 --delr_freq 10 --lr 0.0002 --niter 50
-```
-
-## Testing the detector
-Modify the dataroot in test.py.
-```sh
-CUDA_VISIBLE_DEVICES=0 python test.py --model_path ./NPR.pth --batch_size {BS}
-```
-
-
-## Acknowledgments
-
-This repository borrows partially from the [CNNDetection](https://github.com/peterwang512/CNNDetection).
diff --git a/base_miner/NPR/config/constants.py b/base_miner/NPR/config/constants.py
deleted file mode 100644
index 50958d8d..00000000
--- a/base_miner/NPR/config/constants.py
+++ /dev/null
@@ -1,10 +0,0 @@
-import os
-
-# Path to the directory containing the constants.py file
-CONFIGS_DIR = os.path.dirname(os.path.abspath(__file__))
-
-# The base directory for NPR-related files, i.e., NPR directory
-NPR_BASE_PATH = os.path.abspath(os.path.join(CONFIGS_DIR, "..")) # Points to dfd-arena/detectors/NPR/
-# Absolute paths for the required files and directories
-WEIGHTS_DIR = os.path.join(NPR_BASE_PATH, "weights/") # Path to pretrained weights directory
-
diff --git a/base_miner/NPR/download_dataset.sh b/base_miner/NPR/download_dataset.sh
deleted file mode 100644
index 1e0c3278..00000000
--- a/base_miner/NPR/download_dataset.sh
+++ /dev/null
@@ -1,52 +0,0 @@
-#!/bin/bash
-pwd=$(cd $(dirname $0); pwd)
-echo pwd: $pwd
-
-# pip install gdown==4.7.1
-
-mkdir dataset
-cd dataset
-
-# --proxy http://ip:port
-
-
-
-# https://github.com/Yuheng-Li/UniversalFakeDetect
-gdown https://drive.google.com/drive/folders/1nkCXClC7kFM01_fqmLrVNtnOYEFPtWO- -O ./UniversalFakeDetect --folder
-cd ./UniversalFakeDetect
-ls | xargs -I pa sh -c "tar -zxvf pa; rm pa"
-cd $pwd/dataset
-
-# https://github.com/chuangchuangtan/FreqNet-DeepfakeDetection
-# https://drive.google.com/drive/folders/11E0Knf9J1qlv2UuTnJSOFUjIIi90czSj?usp=sharing
-gdown https://drive.google.com/drive/folders/11E0Knf9J1qlv2UuTnJSOFUjIIi90czSj -O ./GANGen-Detection --folder
-
-cd ./GANGen-Detection
-ls | xargs -I pa sh -c "tar -zxvf pa; rm pa"
-cd $pwd/dataset
-
-# https://github.com/ZhendongWang6/DIRE
-# https://drive.google.com/drive/folders/1tKsOU-6FDdstrrKLPYuZ7RpQwtOSHxUD?usp=sharing
-gdown https://drive.google.com/drive/folders/1tKsOU-6FDdstrrKLPYuZ7RpQwtOSHxUD -O ./DiffusionForensics --folder
-
-cd ./DiffusionForensics
-ls | xargs -I pa sh -c "tar -zxvf pa; rm pa"
-cd $pwd/dataset
-
-# https://github.com/Ekko-zn/AIGCDetectBenchmark
-# https://drive.google.com/drive/folders/1BUv1MT1cm90QN3WTMHLEr8PXBsKGxKC9?usp=sharing
-gdown https://drive.google.com/drive/folders/1BUv1MT1cm90QN3WTMHLEr8PXBsKGxKC9 -O ./AIGCDetect_testset --folder
-zip -s- test.zip -O test_full.zip
-unzip test_full.zip -d ./AIGCDetect_testset
-cd $pwd/dataset
-
-gdown https://drive.google.com/drive/folders/14f0vApTLiukiPvIHukHDzLujrvJpDpRq -O ./Diffusion1kStep --folder
-cd ./Diffusion1kStep
-ls | xargs -I pa sh -c "tar -zxvf pa; rm pa"
-cd $pwd/dataset
-
-
-# https://github.com/peterwang512/CNNDetection
-gdown 'https://drive.google.com/u/0/uc?id=1z_fD3UKgWQyOTZIBbYSaQ-hz4AzUrLC1' -O CNN_synth_testset.zip --continue
-unzip CNN_synth_testset.zip -d ./ForenSynths
-
diff --git a/base_miner/NPR/networks/__init__.py b/base_miner/NPR/networks/__init__.py
deleted file mode 100755
index e69de29b..00000000
diff --git a/base_miner/NPR/networks/base_model.py b/base_miner/NPR/networks/base_model.py
deleted file mode 100755
index 1e1024d0..00000000
--- a/base_miner/NPR/networks/base_model.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# from pix2pix
-import os
-import torch
-import torch.nn as nn
-from torch.nn import init
-from torch.optim import lr_scheduler
-
-
-class BaseModel(nn.Module):
- def __init__(self, opt):
- super(BaseModel, self).__init__()
- self.opt = opt
- self.total_steps = 0
- self.isTrain = opt.isTrain
- self.lr = opt.lr
- self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
- self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
-
- def save_networks(self, epoch):
- save_filename = 'model_epoch_%s.pth' % epoch
- save_path = os.path.join(self.save_dir, save_filename)
-
- # serialize model and optimizer to dict
- # state_dict = {
- # 'model': self.model.state_dict(),
- # 'optimizer' : self.optimizer.state_dict(),
- # 'total_steps' : self.total_steps,
- # }
-
- torch.save(self.model.state_dict(), save_path)
- print(f'Saving model {save_path}')
-
- # load models from the disk
- def load_networks(self, epoch):
- load_filename = 'model_epoch_%s.pth' % epoch
- load_path = os.path.join(self.save_dir, load_filename)
-
- print('loading the model from %s' % load_path)
- # if you are using PyTorch newer than 0.4 (e.g., built from
- # GitHub source), you can remove str() on self.device
- state_dict = torch.load(load_path, map_location=self.device)
- if hasattr(state_dict, '_metadata'):
- del state_dict._metadata
-
- self.model.load_state_dict(state_dict['model'])
- self.total_steps = state_dict['total_steps']
-
- if self.isTrain and not self.opt.new_optim:
- self.optimizer.load_state_dict(state_dict['optimizer'])
- ### move optimizer state to GPU
- for state in self.optimizer.state.values():
- for k, v in state.items():
- if torch.is_tensor(v):
- state[k] = v.to(self.device)
-
- for g in self.optimizer.param_groups:
- g['lr'] = self.opt.lr
-
- def eval(self):
- self.model.eval()
-
- def train(self):
- self.model.train()
-
- def test(self):
- with torch.no_grad():
- self.forward()
-
-
-def init_weights(net, init_type='normal', gain=0.02):
- def init_func(m):
- classname = m.__class__.__name__
- if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
- if init_type == 'normal':
- init.normal_(m.weight.data, 0.0, gain)
- elif init_type == 'xavier':
- init.xavier_normal_(m.weight.data, gain=gain)
- elif init_type == 'kaiming':
- init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
- elif init_type == 'orthogonal':
- init.orthogonal_(m.weight.data, gain=gain)
- else:
- raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
- if hasattr(m, 'bias') and m.bias is not None:
- init.constant_(m.bias.data, 0.0)
- elif classname.find('BatchNorm2d') != -1:
- init.normal_(m.weight.data, 1.0, gain)
- init.constant_(m.bias.data, 0.0)
-
- print('initialize network with %s' % init_type)
- net.apply(init_func)
diff --git a/base_miner/NPR/networks/resnet.py b/base_miner/NPR/networks/resnet.py
deleted file mode 100755
index 2a5f40ae..00000000
--- a/base_miner/NPR/networks/resnet.py
+++ /dev/null
@@ -1,235 +0,0 @@
-import torch.nn as nn
-import torch.utils.model_zoo as model_zoo
-from torch.nn import functional as F
-from typing import Any, cast, Dict, List, Optional, Union
-import numpy as np
-
-__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
- 'resnet152']
-
-
-model_urls = {
- 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
- 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
- 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
- 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
- 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
-}
-
-
-def conv3x3(in_planes, out_planes, stride=1):
- """3x3 convolution with padding"""
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
- padding=1, bias=False)
-
-
-def conv1x1(in_planes, out_planes, stride=1):
- """1x1 convolution"""
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-
-class BasicBlock(nn.Module):
- expansion = 1
-
- def __init__(self, inplanes, planes, stride=1, downsample=None):
- super(BasicBlock, self).__init__()
- self.conv1 = conv3x3(inplanes, planes, stride)
- self.bn1 = nn.BatchNorm2d(planes)
- self.relu = nn.ReLU(inplace=True)
- self.conv2 = conv3x3(planes, planes)
- self.bn2 = nn.BatchNorm2d(planes)
- self.downsample = downsample
- self.stride = stride
-
- def forward(self, x):
- identity = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
-
- if self.downsample is not None:
- identity = self.downsample(x)
-
- out += identity
- out = self.relu(out)
-
- return out
-
-
-class Bottleneck(nn.Module):
- expansion = 4
-
- def __init__(self, inplanes, planes, stride=1, downsample=None):
- super(Bottleneck, self).__init__()
- self.conv1 = conv1x1(inplanes, planes)
- self.bn1 = nn.BatchNorm2d(planes)
- self.conv2 = conv3x3(planes, planes, stride)
- self.bn2 = nn.BatchNorm2d(planes)
- self.conv3 = conv1x1(planes, planes * self.expansion)
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
- self.relu = nn.ReLU(inplace=True)
- self.downsample = downsample
- self.stride = stride
-
- def forward(self, x):
- identity = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
- out = self.relu(out)
-
- out = self.conv3(out)
- out = self.bn3(out)
-
- if self.downsample is not None:
- identity = self.downsample(x)
-
- out += identity
- out = self.relu(out)
-
- return out
-
-
-class ResNet(nn.Module):
-
- def __init__(self, block, layers, num_classes=1, zero_init_residual=False):
- super(ResNet, self).__init__()
-
- self.unfoldSize = 2
- self.unfoldIndex = 0
- assert self.unfoldSize > 1
- assert -1 < self.unfoldIndex and self.unfoldIndex < self.unfoldSize*self.unfoldSize
- self.inplanes = 64
- self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
- self.bn1 = nn.BatchNorm2d(64)
- self.relu = nn.ReLU(inplace=True)
- self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
- self.layer1 = self._make_layer(block, 64 , layers[0])
- self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
- self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
- # self.fc1 = nn.Linear(512 * block.expansion, 1)
- self.fc1 = nn.Linear(512, num_classes)
-
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
- elif isinstance(m, nn.BatchNorm2d):
- nn.init.constant_(m.weight, 1)
- nn.init.constant_(m.bias, 0)
-
- # Zero-initialize the last BN in each residual branch,
- # so that the residual branch starts with zeros, and each residual block behaves like an identity.
- # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
- if zero_init_residual:
- for m in self.modules():
- if isinstance(m, Bottleneck):
- nn.init.constant_(m.bn3.weight, 0)
- elif isinstance(m, BasicBlock):
- nn.init.constant_(m.bn2.weight, 0)
-
- def _make_layer(self, block, planes, blocks, stride=1):
- downsample = None
- if stride != 1 or self.inplanes != planes * block.expansion:
- downsample = nn.Sequential(
- conv1x1(self.inplanes, planes * block.expansion, stride),
- nn.BatchNorm2d(planes * block.expansion),
- )
-
- layers = []
- layers.append(block(self.inplanes, planes, stride, downsample))
- self.inplanes = planes * block.expansion
- for _ in range(1, blocks):
- layers.append(block(self.inplanes, planes))
-
- return nn.Sequential(*layers)
- def interpolate(self, img, factor):
- return F.interpolate(F.interpolate(img, scale_factor=factor, mode='nearest', recompute_scale_factor=True), scale_factor=1/factor, mode='nearest', recompute_scale_factor=True)
- def forward(self, x):
- # n,c,w,h = x.shape
- # if -1*w%2 != 0: x = x[:,:,:w%2*-1,: ]
- # if -1*h%2 != 0: x = x[:,:,: ,:h%2*-1]
- # factor = 0.5
- # x_half = F.interpolate(x, scale_factor=factor, mode='nearest', recompute_scale_factor=True)
- # x_re = F.interpolate(x_half, scale_factor=1/factor, mode='nearest', recompute_scale_factor=True)
- # NPR = x - x_re
- # n,c,w,h = x.shape
- # if w%2 == 1 : x = x[:,:,:-1,:]
- # if h%2 == 1 : x = x[:,:,:,:-1]
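- # NPR: residual between the input and its 0.5x down-then-up-sampled copy, which exposes up-sampling artifacts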
- NPR = x - self.interpolate(x, 0.5)
-
- x = self.conv1(NPR*2.0/3.0)
- x = self.bn1(x)
- x = self.relu(x)
- x = self.maxpool(x)
-
- x = self.layer1(x)
- x = self.layer2(x)
-
- x = self.avgpool(x)
- x = x.view(x.size(0), -1)
- x = self.fc1(x)
-
- return x
-
-
-def resnet18(pretrained=False, **kwargs):
- """Constructs a ResNet-18 model.
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
- if pretrained:
- model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
- return model
-
-
-def resnet34(pretrained=False, **kwargs):
- """Constructs a ResNet-34 model.
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
- if pretrained:
- model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
- return model
-
-
-def resnet50(pretrained=False, **kwargs):
- """Constructs a ResNet-50 model.
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
- if pretrained:
- model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
- return model
-
-
-def resnet101(pretrained=False, **kwargs):
- """Constructs a ResNet-101 model.
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
- if pretrained:
- model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
- return model
-
-
-def resnet152(pretrained=False, **kwargs):
- """Constructs a ResNet-152 model.
- Args:
- pretrained (bool): If True, returns a model pre-trained on ImageNet
- """
- model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
- if pretrained:
- model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
- return model
diff --git a/base_miner/NPR/networks/trainer.py b/base_miner/NPR/networks/trainer.py
deleted file mode 100755
index 96778238..00000000
--- a/base_miner/NPR/networks/trainer.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import functools
-import torch
-import torch.nn as nn
-from networks.resnet import resnet50, resnet152, resnet101
-from networks.base_model import BaseModel, init_weights
-
-
-class Trainer(BaseModel):
- def name(self):
- return 'Trainer'
-
- def __init__(self, opt):
- super(Trainer, self).__init__(opt)
-
- if self.isTrain and not opt.continue_train:
- self.model = resnet50(pretrained=False, num_classes=1)
-
- if not self.isTrain or opt.continue_train:
- self.model = resnet50(num_classes=1)
-
- if self.isTrain:
- self.loss_fn = nn.BCEWithLogitsLoss()
- # initialize optimizers
- if opt.optim == 'adam':
- self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()),
- lr=opt.lr, betas=(opt.beta1, 0.999))
- elif opt.optim == 'sgd':
- self.optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()),
- lr=opt.lr, momentum=0.0, weight_decay=0)
- else:
- raise ValueError("optim should be [adam, sgd]")
-
- if not self.isTrain or opt.continue_train:
- self.load_networks(opt.epoch)
- self.model.to(opt.gpu_ids[0])
-
-
- def adjust_learning_rate(self, min_lr=1e-6):
- for param_group in self.optimizer.param_groups:
- param_group['lr'] *= 0.9
- if param_group['lr'] < min_lr:
- return False
- self.lr = param_group['lr']
- print('*'*25)
- print(f'Changing lr from {param_group["lr"]/0.9} to {param_group["lr"]}')
- print('*'*25)
- return True
-
- def set_input(self, batch, device='cuda'):
- keep_idx = [
- i for i, b in enumerate(batch)
- if isinstance(b[0], torch.Tensor) and
- isinstance(batch[0][0], torch.Tensor) and
- b[0].shape[0] == batch[0][0].shape[0]
- ]
- inputs = torch.stack([b[0] for i, b in enumerate(batch) if i in keep_idx])
- labels = torch.stack([torch.tensor(b[1]) for i, b in enumerate(batch) if i in keep_idx])
- self.input, self.label = inputs.to(device).float(), labels.to(device).float()
-
-
- def forward(self):
- self.output = self.model(self.input)
-
- def get_loss(self):
- return self.loss_fn(self.output.squeeze(1), self.label)
-
- def optimize_parameters(self):
- self.forward()
- self.loss = self.loss_fn(self.output.squeeze(1), self.label)
- self.optimizer.zero_grad()
- self.loss.backward()
- self.optimizer.step()
-
diff --git a/base_miner/NPR/options/__init__.py b/base_miner/NPR/options/__init__.py
deleted file mode 100755
index e8275243..00000000
--- a/base_miner/NPR/options/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .train_options import TrainOptions
\ No newline at end of file
diff --git a/base_miner/NPR/options/base_options.py b/base_miner/NPR/options/base_options.py
deleted file mode 100755
index 1f679328..00000000
--- a/base_miner/NPR/options/base_options.py
+++ /dev/null
@@ -1,118 +0,0 @@
-import argparse
-import os
-import time
-import util
-import torch
-#import models
-#import data
-
-
-class BaseOptions():
- def __init__(self):
- self.initialized = False
-
- def initialize(self, parser):
- parser.add_argument('--mode', default='binary')
- parser.add_argument('--arch', type=str, default='res50', help='architecture for binary classification')
-
- # data augmentation
- parser.add_argument('--rz_interp', default='bilinear')
- parser.add_argument('--blur_prob', type=float, default=0)
- parser.add_argument('--blur_sig', default='0.5')
- parser.add_argument('--jpg_prob', type=float, default=0)
- parser.add_argument('--jpg_method', default='cv2')
- parser.add_argument('--jpg_qual', default='75')
-
- parser.add_argument('--dataroot', default='./dataset/', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
- parser.add_argument('--classes', default='', help='image classes to train on')
- parser.add_argument('--class_bal', action='store_true')
- parser.add_argument('--batch_size', type=int, default=64, help='input batch size')
- parser.add_argument('--loadSize', type=int, default=256, help='scale images to this size')
- parser.add_argument('--cropSize', type=int, default=224, help='then crop to this size')
- parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
- parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
- parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
- parser.add_argument('--num_threads', default=8, type=int, help='# threads for loading data')
- parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
- parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
- parser.add_argument('--resize_or_crop', type=str, default='scale_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop|none]')
- parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
- parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
- parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
- parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{loadSize}')
- parser.add_argument('--delr_freq', type=int, default=20, help='frequency of changing lr')
-
-
- self.initialized = True
- return parser
-
- def gather_options(self):
- # initialize parser with basic options
- if not self.initialized:
- parser = argparse.ArgumentParser(
- formatter_class=argparse.ArgumentDefaultsHelpFormatter)
- parser = self.initialize(parser)
-
- # get the basic options
- opt, _ = parser.parse_known_args()
- self.parser = parser
-
- return opt #parser.parse_args()
-
- def print_options(self, opt):
- message = ''
- message += '----------------- Options ---------------\n'
- for k, v in sorted(vars(opt).items()):
- comment = ''
- default = self.parser.get_default(k)
- if v != default:
- comment = '\t[default: %s]' % str(default)
- message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
- message += '----------------- End -------------------'
- print(message)
-
- # save to the disk
-
- expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
- util.mkdirs(expr_dir)
- file_name = os.path.join(expr_dir, 'opt.txt')
- with open(file_name, 'wt') as opt_file:
- opt_file.write(message)
- opt_file.write('\n')
-
- def parse(self, print_options=True):
-
- opt = self.gather_options()
- opt.isTrain = self.isTrain # train or test
- opt.name = opt.name + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
- # process opt.suffix
- if opt.suffix:
- suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
- opt.name = opt.name + suffix
-
- if print_options:
- self.print_options(opt)
-
- # set gpu ids
- str_ids = opt.gpu_ids.split(',')
- opt.gpu_ids = []
- for str_id in str_ids:
- id = int(str_id)
- if id >= 0:
- opt.gpu_ids.append(id)
- if len(opt.gpu_ids) > 0:
- torch.cuda.set_device(opt.gpu_ids[0])
-
- # additional
- opt.classes = opt.classes.split(',')
- opt.rz_interp = opt.rz_interp.split(',')
- opt.blur_sig = [float(s) for s in opt.blur_sig.split(',')]
- opt.jpg_method = opt.jpg_method.split(',')
- opt.jpg_qual = [int(s) for s in opt.jpg_qual.split(',')]
- if len(opt.jpg_qual) == 2:
- opt.jpg_qual = list(range(opt.jpg_qual[0], opt.jpg_qual[1] + 1))
- elif len(opt.jpg_qual) > 2:
- raise ValueError("Shouldn't have more than 2 values for --jpg_qual.")
-
- self.opt = opt
- return self.opt
diff --git a/base_miner/NPR/options/test_options.py b/base_miner/NPR/options/test_options.py
deleted file mode 100755
index 1dba350e..00000000
--- a/base_miner/NPR/options/test_options.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from .base_options import BaseOptions
-
-
-class TestOptions(BaseOptions):
- def initialize(self, parser):
- parser = BaseOptions.initialize(self, parser)
- # parser.add_argument('--dataroot')
- parser.add_argument('--model_path')
- parser.add_argument('--no_resize', action='store_true')
- parser.add_argument('--no_crop', action='store_true')
- parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
- parser.add_argument('--earlystop_epoch', type=int, default=15)
- parser.add_argument('--lr', type=float, default=0.00002, help='initial learning rate for adam')
- parser.add_argument('--niter', type=int, default=0, help='# of iter at starting learning rate')
-
- self.isTrain = False
- return parser
diff --git a/base_miner/NPR/options/train_options.py b/base_miner/NPR/options/train_options.py
deleted file mode 100755
index cf16ee12..00000000
--- a/base_miner/NPR/options/train_options.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from .base_options import BaseOptions
-
-
-class TrainOptions(BaseOptions):
- def initialize(self, parser):
- parser = BaseOptions.initialize(self, parser)
- parser.add_argument('--earlystop_epoch', type=int, default=15)
- parser.add_argument('--data_aug', action='store_true', help='if specified, perform additional data augmentation (photometric, blurring, jpegging)')
- parser.add_argument('--optim', type=str, default='adam', help='optim to use [sgd, adam]')
- parser.add_argument('--new_optim', action='store_true', help='new optimizer instead of loading the optim state')
- parser.add_argument('--loss_freq', type=int, default=400, help='frequency of showing loss on tensorboard')
- parser.add_argument('--save_latest_freq', type=int, default=2000, help='frequency of saving the latest results')
- parser.add_argument('--save_epoch_freq', type=int, default=20, help='frequency of saving checkpoints at the end of epochs')
- parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
- parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
- parser.add_argument('--last_epoch', type=int, default=-1, help='starting epoch count for scheduler initialization')
- parser.add_argument('--train_split', type=str, default='train', help='train, val, test, etc')
- parser.add_argument('--val_split', type=str, default='val', help='train, val, test, etc')
- parser.add_argument('--niter', type=int, default=1000, help='# of iter at starting learning rate')
- parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam')
- parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
-
- # parser.add_argument('--model_path')
- # parser.add_argument('--no_resize', action='store_true')
- # parser.add_argument('--no_crop', action='store_true')
- self.isTrain = True
- return parser
diff --git a/base_miner/NPR/requirements.txt b/base_miner/NPR/requirements.txt
deleted file mode 100755
index 5229c276..00000000
--- a/base_miner/NPR/requirements.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-scipy
-scikit-learn
-numpy
-opencv_python
-Pillow
-torch>=1.2.0
-torchvision
diff --git a/base_miner/NPR/test.py b/base_miner/NPR/test.py
deleted file mode 100644
index 07966fbe..00000000
--- a/base_miner/NPR/test.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import sys
-import time
-import os
-import csv
-import torch
-from util import Logger, printSet
-from validate import validate
-from networks.resnet import resnet50
-from options.test_options import TestOptions
-import networks.resnet as resnet
-import numpy as np
-import random
-def seed_torch(seed=1029):
- random.seed(seed)
- os.environ['PYTHONHASHSEED'] = str(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- torch.cuda.manual_seed(seed)
- torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
- torch.backends.cudnn.benchmark = False
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.enabled = False
-seed_torch(100)
-DetectionTests = {
- 'ForenSynths': { 'dataroot' : '/opt/data/private/DeepfakeDetection/ForenSynths/',
- 'no_resize' : False, # Due to the different shapes of images in the dataset, resizing is required during batch detection.
- 'no_crop' : True,
- },
-
- 'GANGen-Detection': { 'dataroot' : '/opt/data/private/DeepfakeDetection/GANGen-Detection/',
- 'no_resize' : True,
- 'no_crop' : True,
- },
-
- 'DiffusionForensics': { 'dataroot' : '/opt/data/private/DeepfakeDetection/DiffusionForensics/',
- 'no_resize' : False, # Due to the different shapes of images in the dataset, resizing is required during batch detection.
- 'no_crop' : True,
- },
-
- 'UniversalFakeDetect': { 'dataroot' : '/opt/data/private/DeepfakeDetection/UniversalFakeDetect/',
- 'no_resize' : False, # Due to the different shapes of images in the dataset, resizing is required during batch detection.
- 'no_crop' : True,
- },
-
- }
-
-
-opt = TestOptions().parse(print_options=False)
-print(f'Model_path {opt.model_path}')
-
-# get model
-model = resnet50(num_classes=1)
-model.load_state_dict(torch.load(opt.model_path, map_location='cpu'), strict=True)
-model.cuda()
-model.eval()
-
-for testSet in DetectionTests.keys():
- dataroot = DetectionTests[testSet]['dataroot']
- printSet(testSet)
-
- accs = [];aps = []
- print(time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()))
- for v_id, val in enumerate(os.listdir(dataroot)):
- opt.dataroot = '{}/{}'.format(dataroot, val)
- opt.classes = '' #os.listdir(opt.dataroot) if multiclass[v_id] else ['']
- opt.no_resize = DetectionTests[testSet]['no_resize']
- opt.no_crop = DetectionTests[testSet]['no_crop']
- acc, ap, _, _, _, _ = validate(model, opt)
- accs.append(acc);aps.append(ap)
- print("({} {:12}) acc: {:.1f}; ap: {:.1f}".format(v_id, val, acc*100, ap*100))
- print("({} {:10}) acc: {:.1f}; ap: {:.1f}".format(v_id+1,'Mean', np.array(accs).mean()*100, np.array(aps).mean()*100));print('*'*25)
-
diff --git a/base_miner/NPR/train_detector.py b/base_miner/NPR/train_detector.py
deleted file mode 100644
index 64d4e608..00000000
--- a/base_miner/NPR/train_detector.py
+++ /dev/null
@@ -1,112 +0,0 @@
-from tensorboardX import SummaryWriter
-from torch.utils.data import DataLoader
-import numpy as np
-import os
-import time
-import random
-import torch
-
-from base_miner.NPR.validate import validate
-from base_miner.NPR.networks.trainer import Trainer
-from base_miner.config import IMAGE_DATASETS as DATASET_META
-from base_miner.NPR.options import TrainOptions
-from bitmind.utils.image_transforms import get_base_transforms, get_random_augmentations
-from base_miner.datasets.util import load_and_split_datasets, create_real_fake_datasets
-
-
-def seed_torch(seed=1029):
- random.seed(seed)
- os.environ['PYTHONHASHSEED'] = str(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- torch.cuda.manual_seed(seed)
- torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
- torch.backends.cudnn.benchmark = False
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.enabled = False
-
-
-def main():
- opt = TrainOptions().parse()
- seed_torch(100)
-
- train_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "train"))
- val_writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, "val"))
-
- # RealFakeDataset will limit the number of images sampled per dataset to the length of the smallest dataset
- base_transforms = get_base_transforms()
- random_augs = get_random_augmentations()
- split_transforms = {
- 'train': random_augs,
- 'validation': base_transforms,
- 'test': base_transforms
- }
- real_datasets = load_and_split_datasets(
- DATASET_META['real'], modality='image', split_transforms=split_transforms)
- fake_datasets = load_and_split_datasets(
- DATASET_META['fake'], modality='image', split_transforms=split_transforms)
- train_dataset, val_dataset, test_dataset = create_real_fake_datasets(
- real_datasets, fake_datasets)
-
- train_loader = DataLoader(
- train_dataset, batch_size=32, shuffle=True, num_workers=0, collate_fn=lambda d: tuple(d))
- val_loader = DataLoader(
- val_dataset, batch_size=32, shuffle=False, num_workers=0, collate_fn=lambda d: tuple(d))
- test_loader = DataLoader(
- test_dataset, batch_size=32, shuffle=False, num_workers=0, collate_fn=lambda d: tuple(d))
-
- model = Trainer(opt)
- display_loss_steps = 10
- early_stopping_epochs = 10
- best_val_acc = 0
- n_epoch_since_improvement = 0
- model.train()
-
- print(f'cwd: {os.getcwd()}')
- for epoch in range(opt.niter):
-
- for step, data in enumerate(train_loader):
- model.set_input(data)
- model.optimize_parameters()
-
- if step % display_loss_steps == 0:
- ts = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
- print(f"{ts} | Step: {step} ({model.total_steps}) | Train loss: {model.loss} | lr {model.lr}")
-
- if model.total_steps % opt.loss_freq == 0:
- train_writer.add_scalar('loss', model.loss, model.total_steps)
-
- model.total_steps += 1
-
- if epoch % opt.delr_freq == 0 and epoch != 0:
- ts = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
- print(ts, 'changing lr at the end of epoch %d, iters %d' % (epoch, model.total_steps))
- model.adjust_learning_rate()
-
- # Validation
- model.eval()
- acc, ap = validate(model.model, val_loader)[:2]
- val_writer.add_scalar('accuracy', acc, model.total_steps)
- val_writer.add_scalar('ap', ap, model.total_steps)
-
- print("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap))
- if acc > best_val_acc:
- model.save_networks('best')
- best_val_acc = acc
- else:
- n_epoch_since_improvement += 1
- if n_epoch_since_improvement >= early_stopping_epochs:
- break
-
- model.train()
-
- model.eval()
- acc, ap = validate(model.model, test_loader)[:2]
- print("(Test) acc: {}; ap: {}".format(acc, ap))
- model.save_networks('last')
-
-
-if __name__ == '__main__':
- main()
-
-
diff --git a/base_miner/NPR/util/__init__.py b/base_miner/NPR/util/__init__.py
deleted file mode 100644
index bfd44f33..00000000
--- a/base_miner/NPR/util/__init__.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import sys
-import os
-import torch
-
-
-def mkdirs(paths):
- if isinstance(paths, list) and not isinstance(paths, str):
- for path in paths:
- mkdir(path)
- else:
- mkdir(paths)
-
-
-def mkdir(path):
- if not os.path.exists(path):
- os.makedirs(path)
-
-
-def unnormalize(tens, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
- # assume tensor of shape NxCxHxW
- return tens * torch.Tensor(std)[None, :, None, None] + torch.Tensor(
- mean)[None, :, None, None]
-
-
-class Logger(object):
- """Log stdout messages."""
-
- def __init__(self, outfile):
- self.terminal = sys.stdout
- self.log = open(outfile, "a")
- sys.stdout = self
-
- def write(self, message):
- self.terminal.write(message)
- self.log.write(message)
-
- def flush(self):
- self.terminal.flush()
-
-
-def printSet(set_str):
- set_str = str(set_str)
- num = len(set_str)
- print("=" * num * 3)
- print(" " * num + set_str)
- print("=" * num * 3)
\ No newline at end of file
diff --git a/base_miner/NPR/util/eval.py b/base_miner/NPR/util/eval.py
deleted file mode 100644
index ace74f42..00000000
--- a/base_miner/NPR/util/eval.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import numpy as np
-
-
-def compute_tfpn(y_true, y_pred):
-
- tp = sum(y_pred[y_true==1] > 0.5)
- fp = sum(y_pred[y_true==0] > 0.5)
- tn = sum(y_pred[y_true==0] <= 0.5)
- fn = sum(y_pred[y_true==1] <= 0.5)
- return tp, fp, tn, fn
-
-
-def compute_metrics(TP, FP, TN, FN):
- precision = TP / (TP + FP) if (TP + FP) != 0 else 0
- recall = TP / (TP + FN) if (TP + FN) != 0 else 0
- if (precision + recall) == 0:
- f1_score = 0
- else:
- f1_score = 2 * (precision * recall) / (precision + recall)
-
- return precision, recall, f1_score
\ No newline at end of file
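For orientation, a small worked example of the two helpers above (hypothetical arrays; the import path is illustrative of the module layout being removed here):

```python
import numpy as np
from base_miner.NPR.util.eval import compute_tfpn, compute_metrics  # illustrative import path

# Hypothetical ground truth (1 = fake, 0 = real) and sigmoid scores
y_true = np.array([1, 1, 0, 0, 1])
y_pred = np.array([0.9, 0.4, 0.2, 0.7, 0.8])

tp, fp, tn, fn = compute_tfpn(y_true, y_pred)             # 2, 1, 1, 1
precision, recall, f1 = compute_metrics(tp, fp, tn, fn)   # each ~0.667
```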
diff --git a/base_miner/NPR/validate.py b/base_miner/NPR/validate.py
deleted file mode 100755
index 82452d56..00000000
--- a/base_miner/NPR/validate.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import torch
-import numpy as np
-from networks.resnet import resnet50
-from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score
-from options.test_options import TestOptions
-
-
-def validate(model, dataloader, device='cuda'):
-
- with torch.no_grad():
- y_true, y_pred = [], []
- for batch in dataloader:
- keep_idx = [i for i, b in enumerate(batch) if b[0].shape[0] == batch[0][0].shape[0]]
- # batch = np.array(batch)
- inputs = torch.stack([b[0] for i, b in enumerate(batch) if i in keep_idx])
- labels = torch.stack([torch.tensor(b[1]) for i, b in enumerate(batch) if i in keep_idx])
- img, label = inputs.to(device).float(), labels.to(device).float()
-
- #in_tens = img.cuda()
- out = model(img).sigmoid().flatten().tolist()
- y_pred.extend(out)
- y_true.extend(label.flatten().tolist())
-
- y_true, y_pred = np.array(y_true), np.array(y_pred)
- r_acc = accuracy_score(y_true[y_true==0], y_pred[y_true==0] > 0.5)
- f_acc = accuracy_score(y_true[y_true==1], y_pred[y_true==1] > 0.5)
- acc = accuracy_score(y_true, y_pred > 0.5)
- ap = average_precision_score(y_true, y_pred)
- return acc, ap, r_acc, f_acc, y_true, y_pred
-
-
-if __name__ == '__main__':
- opt = TestOptions().parse(print_options=False)
-
- model = resnet50(num_classes=1)
- state_dict = torch.load(opt.model_path, map_location='cpu')
- model.load_state_dict(state_dict['model'])
- model.cuda()
- model.eval()
-
- acc, avg_precision, r_acc, f_acc, y_true, y_pred = validate(model, opt)
-
- print("accuracy:", acc)
- print("average precision:", avg_precision)
-
- print("accuracy of real images:", r_acc)
- print("accuracy of fake images:", f_acc)
diff --git a/base_miner/README.md b/base_miner/README.md
deleted file mode 100644
index c6b4c2a9..00000000
--- a/base_miner/README.md
+++ /dev/null
@@ -1,43 +0,0 @@
-## Base Miners
-
-The `base_miner/` directory facilitates the training, orchestration, and deployment of modular and highly customizable deepfake detectors.
-We broadly define a **detector** as an algorithm that either employs a single model or orchestrates multiple models to perform the binary real-or-AI inference task. Each **model** can be any algorithm that processes an image to determine its classification, including not only pretrained machine learning architectures but also heuristic and statistical modeling frameworks.
-
-## Our Base Miner Detector: Content-Aware Model Orchestration (CAMO)
-
-Read about [CAMO (Content Aware Model Orchestration)](https://bitmindlabs.notion.site/CAMO-Content-Aware-Model-Orchestration-CAMO-Framework-for-Deepfake-Detection-43ef46a0f9de403abec7a577a45cd075), our generalized framework for creating “hard mixture of expert” detectors.
-
-- **Latest Iteration**: The most performant iteration of `class CAMOImageDetector(DeepfakeDetector)`, used in our base miner `neurons/miner.py`, incorporates a `GatingMechanism(Gate)` that routes images to a fine-tuned face expert model and a generalist model, both built on the `UCF` architecture.
-
-## Directory Structure
-
-### 1. Architectures and Training
-- **UCF/** and **NPR/**
-
-These folders contain model architectures and training loops for `UCF (ICCV 2023)` and `NPR (CVPR 2024)`, adapted to use curated and preprocessed training datasets on our [BitMind Huggingface](https://huggingface.co/bitmind).
-
-### 2. deepfake_detectors/
-The modular structure for detectors used in the miner neuron is defined here, through `DeepfakeDetector` abstract base class and subclass implementations.
-
-- **deepfake_detectors/** contains:
- - **configs/**: YAML configuration files to load detector instance attributes, including any pretrained model weights.
- - **Abstract Base Class**: A foundational class that outlines the standard structure for implementing detectors.
- - **Detector Subclasses**: Specialized detector implementations that can be dynamically loaded and managed based on configuration.
-
-The `DeepfakeDetector` design allows for high configurability and extension.
-
-### 3. gating_mechanisms/
-Similar to `deepfake_detectors/`, this folder contains abstract base classes and implementations of `Gate`s that are used to handle content-aware preprocessing and routing. This is especially useful for multi-detector systems, such as the `DeepfakeDetector` subclass `CAMOImageDetector` in `deepfake_detectors/camo_detector.py`.
-
-- **Abstract Gate Class**: A base class for implementing image content gating.
-- **Gate Subclasses**: These subclasses define specific gating mechanisms responsible for routing inputs to appropriate expert detectors or preprocessing steps based on content characteristics. This is useful for multi-detector or mixture-of-expert detector setups.
-
-### 4. registry.py
-The `registry.py` file is responsible for managing the creation of detectors and gates using a **Factory Method** design pattern. It auto-registers all `DeepfakeDetector` and `Gate` subclasses from their subfolders to respective `Registry` constants, making it simple to instantiate detectors and gates dynamically based on predefined constants.
-
-- **Factory Pattern**: Ensures a clean, maintainable, and scalable method for creating instances of detectors and gating mechanisms.
-- **Auto-Registration**: Automatically registers all available detector and gate subclasses, enabling a flexible and extensible system.
-
-## Integration with `miner.py`
-
-- **Modular Initialization**: The miner neuron in `bitmind-subnet/neurons/miner.py` leverages the registry system to dynamically initialize the detector used for the forward function, facilitating a highly modular design. The detector module used is determined by neuron config args, defaulting to `"CAMO"`.
\ No newline at end of file
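The `registry.py` module described above is not part of this hunk; as a point of reference, here is a minimal sketch of the decorator-based auto-registration pattern it describes. This is illustrative only, not the subnet's actual implementation, and the class name is a stand-in:

```python
class Registry(dict):
    """Maps a registered module name to its class."""

    def register_module(self, module_name: str):
        def decorator(cls):
            self[module_name] = cls
            return cls
        return decorator


DETECTOR_REGISTRY = Registry()


@DETECTOR_REGISTRY.register_module(module_name='CAMO')
class ExampleCamoDetector:  # stand-in for the real CAMOImageDetector
    def __init__(self, config_name: str = 'camo.yaml', device: str = 'cpu'):
        self.config_name, self.device = config_name, device


# A miner neuron can then resolve and instantiate the detector from a config string:
detector = DETECTOR_REGISTRY['CAMO'](device='cpu')
```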
diff --git a/base_miner/__init__.py b/base_miner/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/base_miner/config.py b/base_miner/config.py
deleted file mode 100644
index e88e822d..00000000
--- a/base_miner/config.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from pathlib import Path
-
-HUGGINGFACE_CACHE_DIR: Path = Path.home() / '.cache' / 'huggingface'
-TARGET_IMAGE_SIZE = (256, 256)
-
-
-IMAGE_DATASETS = {
- "real": [
- {"path": "bitmind/bm-real"},
- {"path": "bitmind/open-images-v7"},
- {"path": "bitmind/celeb-a-hq"},
- {"path": "bitmind/ffhq-256"},
- {"path": "bitmind/MS-COCO-unique-256"}
- ],
- "fake": [
- {"path": "bitmind/bm-realvisxl"},
- {"path": "bitmind/bm-mobius"},
- {"path": "bitmind/bm-sdxl"}
- ]
-}
-
-# see bitmind-subnet/create_video_dataset_example.sh
-VIDEO_DATASETS = {
- "real": [
- {"path": ""}
- ],
- "fake": [
- {"path": ""}
- ]
-}
-
-FACE_IMAGE_DATASETS = {
- "real": [
- {"path": "bitmind/ffhq-256_training_faces", "name": "base_transforms"},
- {"path": "bitmind/celeb-a-hq_training_faces", "name": "base_transforms"}
-
- ],
- "fake": [
- {"path": "bitmind/ffhq-256___stable-diffusion-xl-base-1.0_training_faces", "name": "base_transforms"},
- {"path": "bitmind/celeb-a-hq___stable-diffusion-xl-base-1.0___256_training_faces", "name": "base_transforms"}
- ]
-}
diff --git a/base_miner/datasets/__init__.py b/base_miner/datasets/__init__.py
deleted file mode 100644
index 78111baa..00000000
--- a/base_miner/datasets/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .base_dataset import BaseDataset
-from .image_dataset import ImageDataset
-from .video_dataset import VideoDataset
-from .real_fake_dataset import RealFakeDataset
diff --git a/base_miner/datasets/base_dataset.py b/base_miner/datasets/base_dataset.py
deleted file mode 100644
index 3dcc8887..00000000
--- a/base_miner/datasets/base_dataset.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from abc import ABC, abstractmethod
-from datasets import Dataset
-from typing import Optional
-from torchvision.transforms import Compose
-
-from base_miner.datasets.download_data import load_huggingface_dataset
-
-
-class BaseDataset(ABC):
- def __init__(
- self,
- huggingface_dataset_path: Optional[str] = None,
- huggingface_dataset_split: str = 'train',
- huggingface_dataset_name: Optional[str] = None,
- huggingface_dataset: Optional[Dataset] = None,
- download_mode: Optional[str] = None,
- transforms: Optional[Compose] = None
- ):
- """Base class for dataset implementations.
-
- Args:
- huggingface_dataset_path (str, optional): Path to the Hugging Face dataset.
-                Can be a publicly hosted dataset (<organization>/<dataset-name>) or a
-                local directory (imagefolder:<path/to/directory>)
- huggingface_dataset_split (str): Dataset split to load. Defaults to 'train'.
- huggingface_dataset_name (str, optional): Name of the specific Hugging Face dataset subset.
- huggingface_dataset (Dataset, optional): Pre-loaded Hugging Face dataset instance.
- download_mode (str, optional): Download mode for the dataset.
- Can be None or "force_redownload"
- """
- self.huggingface_dataset_path = None
- self.huggingface_dataset_split = huggingface_dataset_split
- self.huggingface_dataset_name = None
- self.dataset = None
- self.transforms = transforms
-
- if huggingface_dataset_path is None and huggingface_dataset is None:
- raise ValueError("Either huggingface_dataset_path or huggingface_dataset must be provided.")
-
- # If a dataset is directly provided, use it
- if huggingface_dataset is not None:
- self.dataset = huggingface_dataset
- self.huggingface_dataset_path = self.dataset.info.dataset_name
- self.huggingface_dataset_name = self.dataset.info.config_name
- try:
- self.huggingface_dataset_split = list(self.dataset.info.splits.keys())[0]
- except AttributeError as e:
-                self.huggingface_dataset_split = None
-
- else:
- # Store the initialization parameters
- self.huggingface_dataset_path = huggingface_dataset_path
- self.huggingface_dataset_name = huggingface_dataset_name
- self.dataset = load_huggingface_dataset(
- huggingface_dataset_path,
- huggingface_dataset_split,
- huggingface_dataset_name,
- download_mode)
-
- @abstractmethod
- def __getitem__(self, index: int) -> dict:
- """Get an item from the dataset.
-
- Args:
- index (int): Index of the item to retrieve.
-
- Returns:
- dict: Dictionary containing the item data.
- """
- pass
-
- @abstractmethod
- def __len__(self) -> int:
- """Get the length of the dataset.
-
- Returns:
- int: Length of the dataset.
- """
- pass
diff --git a/base_miner/datasets/create_video_dataset.py b/base_miner/datasets/create_video_dataset.py
deleted file mode 100644
index c834cdf7..00000000
--- a/base_miner/datasets/create_video_dataset.py
+++ /dev/null
@@ -1,305 +0,0 @@
-from collections import defaultdict
-from pathlib import Path
-from typing import Dict, List, Optional, Union, Tuple
-from multiprocessing import Pool, cpu_count
-
-import cv2
-import glob
-import os
-
-import argparse
-from PIL import Image
-from datasets import Dataset, DatasetInfo, Image as HFImage, Split
-from datasets.features import Features, Sequence, Value
-from tqdm import tqdm
-
-
-def process_single_video(args: Tuple[Path, Path, int, Optional[int], bool]) -> Tuple[str, int]:
- """
- Extract frames from a single video
-
- Args:
- args: Tuple containing (video_file, output_dir, frame_rate, max_frames, overwrite)
-
- Returns:
- Tuple of (video_name, number_of_frames_saved)
- """
- video_file, output_dir, frame_rate, max_frames, overwrite = args
- video_name = video_file.stem
- video_output_dir = output_dir / video_name
-
- if video_output_dir.exists() and not overwrite:
- return video_name, 0
-
- video_output_dir.mkdir(parents=True, exist_ok=True)
-
- video_capture = cv2.VideoCapture(str(video_file))
- frame_idx = 0
- saved_frame_count = 0
-
- while True:
- success, frame = video_capture.read()
- if not success or (max_frames is not None and saved_frame_count >= max_frames):
- break
-
- if frame_idx % frame_rate == 0:
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
- pil_image = Image.fromarray(frame_rgb)
- frame_filename = video_output_dir / f"frame_{frame_idx:05d}.png"
- pil_image.save(frame_filename)
- saved_frame_count += 1
-
- frame_idx += 1
-
- video_capture.release()
- return video_name, saved_frame_count
-
-
-def extract_frames_from_videos(
- input_dir: Union[str, Path],
- output_dir: Union[str, Path],
- num_videos: Optional[int] = None,
- frame_rate: int = 1,
- max_frames: Optional[int] = None,
- overwrite: bool = False,
- num_workers: Optional[int] = None
-) -> None:
- """
-    Extract frames from videos (mp4s -> directories of PNG frames) using multiprocessing
-
- Args:
- input_dir: Directory containing input MP4 files
- output_dir: Directory where extracted frames will be saved
- num_videos: Number of videos to process. If None, processes all videos
- frame_rate: Extract one frame every 'frame_rate' frames
- max_frames: Maximum number of frames to extract per video
- overwrite: If True, overwrites existing frame directories
- num_workers: Number of worker processes to use. If None, uses CPU count
- """
- input_dir = Path(input_dir)
- output_dir = Path(output_dir)
- output_dir.mkdir(parents=True, exist_ok=True)
-
- video_files = list(input_dir.glob("*.mp4"))
- if num_videos is not None:
- video_files = video_files[:num_videos]
-
- if not num_workers:
- num_workers = cpu_count()
-
- print(f'Processing {len(video_files)} videos using {num_workers} workers')
-
- # Prepare arguments for each video
- process_args = [
- (video_file, output_dir, frame_rate, max_frames, overwrite)
- for video_file in video_files
- ]
-
- # Process videos in parallel
- with Pool(num_workers) as pool:
- results = list(tqdm(
- pool.imap(process_single_video, process_args),
- total=len(video_files),
- desc="Extracting frames"
- ))
-
- # Print results
- for video_name, frame_count in results:
- if frame_count > 0:
- print(f"Extracted {frame_count} frames from {video_name}")
- else:
- print(f"Skipped {video_name} (already exists)")
-
-
-def create_video_frames_dataset(
- frames_dir: Union[str, Path],
- dataset_name: str = "video_frames",
- validate_frames: bool = False,
- delete_corrupted: bool = False,
-) -> Dataset:
- """Create a HuggingFace dataset from a directory of video frames."""
- frames_dir = Path(frames_dir)
- video_data: Dict[str, Dict[str, List]] = defaultdict(lambda: {'frames': [], 'frame_numbers': []})
-
- for video_dir in tqdm(sorted(os.listdir(frames_dir)), desc='processing video frames'):
- video_path = frames_dir / video_dir
-
- if not video_path.is_dir():
- continue
-
- image_files = []
- for ext in ('*.png', '*.jpg', '*.jpeg'):
- image_files.extend(glob.glob(str(video_path / ext)))
-
- image_files.sort()
-
- # Validate images before adding them to the dataset
- if validate_frames:
- valid_frames = []
- valid_frame_numbers = []
- for img_path in tqdm(image_files, desc="Checking image files"):
- try:
- # Attempt to fully load the image to verify it's valid
- with Image.open(img_path) as img:
- img.load() # Force load the image data
- frame_num = int(Path(img_path).stem.split('_')[1])
- valid_frames.append(img_path)
- valid_frame_numbers.append(frame_num)
- except Exception as e:
- print(f"Skipping corrupted image {img_path}: {str(e)}")
- if delete_corrupted:
- print(f"Deleting {img_path} (delete_corrupted = true)")
- Path(img_path).unlink()
- continue
- if valid_frames: # Only add videos that have valid frames
- video_data[video_dir]['frames'] = valid_frames
- video_data[video_dir]['frame_numbers'] = valid_frame_numbers
- else:
- video_data[video_dir]['frames'] = image_files
- video_data[video_dir]['frame_numbers'] = list(range(len(image_files)))
- print(video_data[video_dir]['frames'][:10])
- print(video_data[video_dir]['frame_numbers'][:10])
-
- dataset_dict = {
- "video_id": [],
- "frames": [],
- "frame_numbers": [],
- "num_frames": []
- }
-
- for video_id, data in video_data.items():
- if data['frames']: # Double check we have frames
- dataset_dict["video_id"].append(video_id)
- dataset_dict["frames"].append(data["frames"])
- dataset_dict["frame_numbers"].append(data["frame_numbers"])
- dataset_dict["num_frames"].append(len(data["frames"]))
-
- features = Features({
- "video_id": Value("string"),
- "frames": Sequence(Value("string")),
- "frame_numbers": Sequence(Value("int64")),
- "num_frames": Value("int64")
- })
-
- dataset_info = DatasetInfo(
- description="Video frames dataset",
- features=features,
- supervised_keys=None,
- homepage="",
- citation="",
- task_templates=None,
- dataset_name=dataset_name
- )
-
- # Create dataset with validated images
- dataset = Dataset.from_dict(
- dataset_dict,
- info=dataset_info,
- features=features
- )
-
- # Convert to HuggingFace image format with error handling
- def convert_frames_to_images(example):
- converted_frames = []
- for frame_path in example["frames"]:
- try:
- converted_frames.append(HFImage().encode_example(frame_path))
- except Exception as e:
- print(f"Error converting {frame_path}: {str(e)}")
- continue
- example["frames"] = converted_frames
- return example
-
- #dataset = dataset.map(convert_frames_to_images)
- return dataset
-
-
-def main() -> None:
- """Parse command line arguments and run the dataset creation pipeline."""
- parser = argparse.ArgumentParser(
- description="Extract frames from videos and create a HuggingFace dataset."
- )
- parser.add_argument(
- "--input_dir",
- type=str,
- required=True,
- help="Path to the directory containing input MP4 files."
- )
- parser.add_argument(
- "--frames_dir",
- type=str,
- required=True,
- help="Path to the directory where extracted frames will be saved."
- )
- parser.add_argument(
- "--dataset_dir",
- type=str,
- required=True,
- help="Path where the HuggingFace dataset will be saved."
- )
- parser.add_argument(
- "--num_videos",
- type=int,
- default=None,
- help="Number of videos to process. If not specified, processes all videos."
- )
- parser.add_argument(
- "--frame_rate",
- type=int,
- default=5,
- help="Extract one frame every 'frame_rate' frames."
- )
- parser.add_argument(
- "--max_frames",
- type=int,
- default=None,
- help="Maximum number of frames to extract per video."
- )
- parser.add_argument(
- "--overwrite",
- action="store_true",
- help="If set, overwrites existing frame directories."
- )
- parser.add_argument(
- "--skip_extraction",
- action="store_true",
- help="If set, skips the frame extraction step and only creates the dataset."
- )
- parser.add_argument(
- "--dataset_name",
- type=str,
- default="video_frames",
- help="Name for the local HuggingFace dataset to be created."
- )
-
- args = parser.parse_args()
-
- if not args.skip_extraction:
- print("Extracting frames from videos...")
- extract_frames_from_videos(
- input_dir=args.input_dir,
- output_dir=args.frames_dir,
- num_videos=args.num_videos,
- frame_rate=args.frame_rate,
- max_frames=args.max_frames,
- overwrite=args.overwrite
- )
-
- print("\nCreating HuggingFace dataset...")
- dataset = create_video_frames_dataset(
- args.frames_dir,
- dataset_name=args.dataset_name
- )
- print(dataset.info)
- print(f"\nSaving dataset to {args.dataset_dir}")
- dataset.save_to_disk(args.dataset_dir)
-
- print(f"\nDataset creation complete!")
- print(f"Total number of videos: {len(dataset)}")
- print(f"Features: {dataset.features}")
- print("Frame counts:", dataset["num_frames"])
- print(f"Dataset name: {dataset.info.dataset_name}")
-
-
-if __name__ == "__main__":
- main()
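A short usage sketch of the two-step pipeline implemented above, calling the functions directly rather than the CLI; all directory paths are placeholders:

```python
# Hypothetical directories; the functions are those defined in create_video_dataset.py above.
extract_frames_from_videos(
    input_dir="videos/real",      # directory of .mp4 files
    output_dir="frames/real",     # one subdirectory of PNG frames per video
    frame_rate=5,                 # keep every 5th frame
    max_frames=24,
    num_workers=4,
)

dataset = create_video_frames_dataset("frames/real", dataset_name="real_video_frames")
dataset.save_to_disk("datasets/real_video_frames")
```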
diff --git a/base_miner/datasets/download_data.py b/base_miner/datasets/download_data.py
deleted file mode 100644
index 485a88ab..00000000
--- a/base_miner/datasets/download_data.py
+++ /dev/null
@@ -1,208 +0,0 @@
-from typing import Optional
-from datasets import load_dataset, load_from_disk
-from PIL import Image
-from io import BytesIO
-import datasets
-import argparse
-import time
-import sys
-import os
-import subprocess
-import glob
-import requests
-
-from base_miner.config import IMAGE_DATASETS, HUGGINGFACE_CACHE_DIR
-
-datasets.logging.set_verbosity_warning()
-datasets.disable_progress_bar()
-
-
-def load_huggingface_dataset(
- path: str,
- split: str = 'train',
- name: Optional[str] = None,
- download_mode: str = 'reuse_cache_if_exists'
-) -> datasets.Dataset:
- """Load a dataset from Hugging Face or a local directory.
-
- Args:
- path (str): Path to dataset. Can be:
-            - A Hugging Face dataset path (<organization>/<dataset-name>)
-            - An image folder path (imagefolder:<path/to/directory>)
- - A local path to a saved dataset (for load_from_disk)
- split (str, optional): Dataset split to load (default: 'train')
- name (str, optional): Dataset configuration name (default: None)
- download_mode (str, optional): Download mode for Hugging Face datasets
-
- Returns:
- Dataset: The loaded dataset or requested split
- """
- # Check if it's a local path suitable for load_from_disk
- if not path.startswith('imagefolder:') and os.path.exists(path):
- try:
- # Look for dataset artifacts that indicate this is a saved dataset
- dataset_files = {'dataset_info.json', 'state.json', 'data'}
- path_contents = set(os.listdir(path))
- if dataset_files.intersection(path_contents):
- dataset = load_from_disk(path)
- if split is None:
- return dataset
- return dataset[split]
- except Exception as e:
- print(f"Attempted load_from_disk but failed: {e}")
-
- if 'imagefolder' in path:
- _, directory = path.split(':')
- if name:
- dataset = load_dataset(path='imagefolder', name=name, data_dir=directory)
- else:
- dataset = load_dataset(path='imagefolder', data_dir=directory)
- else:
- dataset = download_dataset(
- dataset_path=path,
- dataset_name=name,
- download_mode=download_mode,
- cache_dir=HUGGINGFACE_CACHE_DIR)
-
- if split is None:
- return dataset
- return dataset[split]
-
-
-def download_image(url: str) -> Image.Image:
- """Download an image from a URL.
-
- Args:
- url (str): The URL of the image to download.
-
- Returns:
- Image.Image or None: The downloaded image as a PIL Image object if
- successful, otherwise None.
- """
- response = requests.get(url)
- if response.status_code == 200:
- image_data = BytesIO(response.content)
- return Image.open(image_data)
- else:
- #print(f"Failed to download image: {response.status_code}")
- return None
-
-
-def download_dataset(
- dataset_path: str,
- dataset_name: str,
- download_mode: str,
- cache_dir: str,
- max_wait: int = 300
-):
- """Downloads the datasets present in datasets.json with exponential backoff.
-
- Args:
- dataset_path (str): Path to the dataset on Hugging Face
- dataset_name (str): Name/config of the dataset subset
-        download_mode (str): Either 'force_redownload' or 'reuse_cache_if_exists'
- cache_dir (str): Huggingface cache directory. ~/.cache/huggingface by default
- max_wait (int, optional): Maximum wait time between retries in seconds. Defaults to 300.
-
- Returns:
- Dataset: The downloaded Hugging Face dataset
- """
- retry_wait = 10 # initial wait time in seconds
- attempts = 0
- print(f"Downloading {dataset_path} (subset={dataset_name}) dataset...")
- while True:
- try:
- dataset = load_dataset(
- dataset_path,
- name=dataset_name, # config/subset name
- cache_dir=cache_dir,
- download_mode=download_mode,
- trust_remote_code=True)
- break
- except Exception as e:
- print(e)
- if '429' in str(e) or 'ReadTimeoutError' in str(e):
- print(f"Rate limit hit or timeout, retrying in {retry_wait}s...")
- elif isinstance(e, PermissionError):
- file_path = str(e).split(": '")[1].rstrip("'")
- print(f"Permission error at {file_path}, attempting to fix...")
- fix_permissions(file_path) # Attempt to fix permissions directly
- clean_cache(cache_dir) # Clear cache to remove any incomplete or locked files
- else:
- print(f"Unexpected error, stopping retries for {dataset_path}")
- raise e
-
- if retry_wait > max_wait:
- print(f"Download failed for {dataset_path} after {attempts} attempts. Try again later")
- sys.exit(1)
-
- time.sleep(retry_wait)
- retry_wait *= 2 # exponential backoff
- attempts += 1
-
- print(f"Downloaded {dataset_path} dataset to {cache_dir}")
- return dataset
-
-
-def clean_cache(cache_dir):
- """Clears lock files and incomplete downloads from the cache directory.
-
- Args:
- cache_dir (str): Path to the Hugging Face cache directory
- """
- lock_files = glob.glob(os.path.join(cache_dir, "*lock"))
- incomplete_files = glob.glob(os.path.join(cache_dir, "downloads", "**", "*.incomplete"), recursive=True)
- try:
- if lock_files:
- subprocess.run(["rm", *lock_files], check=True)
- if incomplete_files:
- for file in incomplete_files:
- os.remove(file)
- print("Hugging Face cache lock files cleared successfully.")
- except Exception as e:
- print(f"Failed to clear Hugging Face cache lock files: {e}")
-
-
-def fix_permissions(path):
- """Attempts to fix permission issues on a given path.
-
- Args:
- path (str): Path to fix permissions for
- """
- try:
- subprocess.run(["chmod", "-R", "775", path], check=True)
- print(f"Fixed permissions for {path}")
- except subprocess.CalledProcessError as e:
- print(f"Failed to fix permissions for {path}: {e}")
-
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser(description='Download Hugging Face datasets for validator challenge generation and miner training.')
- parser.add_argument('--force_redownload', action='store_true', help='force redownload of datasets')
- parser.add_argument('--modality', default='image', choices=['video', 'image'], help='download image or video datasets')
- parser.add_argument('--cache_dir', type=str, default=HUGGINGFACE_CACHE_DIR, help='huggingface cache directory')
- args = parser.parse_args()
-
- download_mode = "reuse_cache_if_exists"
- if args.force_redownload:
- download_mode = "force_redownload"
-
- os.makedirs(args.cache_dir, exist_ok=True)
- clean_cache(args.cache_dir) # Clear the cache of lock and incomplete files.
-
- if args.modality == 'image':
- dataset_meta = IMAGE_DATASETS
- #elif args.modality == 'video':
- # dataset_meta = VIDEO_DATASET_META
-
- for dataset_type in dataset_meta:
- for dataset in dataset_meta[dataset_type]:
- download_dataset(
- dataset_path=dataset['path'],
- dataset_name=dataset.get('name', None),
- download_mode=download_mode,
- cache_dir=args.cache_dir)
diff --git a/base_miner/datasets/image_dataset.py b/base_miner/datasets/image_dataset.py
deleted file mode 100644
index 09aa4897..00000000
--- a/base_miner/datasets/image_dataset.py
+++ /dev/null
@@ -1,113 +0,0 @@
-from typing import Optional
-from datasets import Dataset
-from PIL import Image
-from io import BytesIO
-from torchvision.transforms import Compose
-
-from .download_data import download_image
-from .base_dataset import BaseDataset
-
-
-class ImageDataset(BaseDataset):
- def __init__(
- self,
- huggingface_dataset_path: Optional[str] = None,
- huggingface_dataset_split: str = 'train',
- huggingface_dataset_name: Optional[str] = None,
- huggingface_dataset: Optional[Dataset] = None,
- download_mode: Optional[str] = None,
- transforms: Optional[Compose] = None,
- ):
- """Initialize the ImageDataset.
-
- Args:
- huggingface_dataset_path (str, optional): Path to the Hugging Face dataset.
-                Can be a publicly hosted dataset (<organization>/<dataset-name>) or a
-                local directory (imagefolder:<path/to/directory>)
- huggingface_dataset_split (str): Dataset split to load. Defaults to 'train'.
- huggingface_dataset_name (str, optional): Name of the specific Hugging Face dataset subset.
- huggingface_dataset (Dataset, optional): Pre-loaded Hugging Face dataset instance.
- download_mode (str, optional): Download mode for the dataset.
- Can be None or "force_redownload"
- """
- super().__init__(
- huggingface_dataset_path=huggingface_dataset_path,
- huggingface_dataset_split=huggingface_dataset_split,
- huggingface_dataset_name=huggingface_dataset_name,
- huggingface_dataset=huggingface_dataset,
- download_mode=download_mode,
- transforms=transforms
- )
-
- def __getitem__(self, index: int) -> dict:
- """
- Get an item (image and ID) from the dataset.
-
- Args:
- index (int): Index of the item to retrieve.
-
- Returns:
- dict: Dictionary containing 'image' (PIL image) and 'id' (str).
- """
- """
- Load an image from self.dataset. Expects self.dataset[i] to be a dictionary containing either 'image' or 'url'
- as a key.
- - The value associated with the 'image' key should be either a PIL image or a b64 string encoding of
- the image.
- - The value associated with the 'url' key should be a url that hosts the image (as in
- dalle-mini/open-images)
-
- Args:
- index (int): Index of the image in the dataset.
-
- Returns:
- dict: Dictionary containing 'image' (PIL image) and 'id' (str).
- """
- sample = self.dataset[int(index)]
- if 'url' in sample:
- image = download_image(sample['url'])
- image_id = sample['url']
- elif 'image_url' in sample:
- image = download_image(sample['image_url'])
- image_id = sample['image_url']
- elif 'image' in sample:
- if isinstance(sample['image'], Image.Image):
- image = sample['image']
- elif isinstance(sample['image'], bytes):
- image = Image.open(BytesIO(sample['image']))
- else:
- raise NotImplementedError
-
- image_id = ''
- if 'name' in sample:
- image_id = sample['name']
- elif 'filename' in sample:
- image_id = sample['filename']
-
- image_id = image_id if image_id != '' else index
-
- else:
- raise NotImplementedError
-
-        # convert to RGB (drop any alpha channel) if the download didn't fail
- if image is not None:
- image = image.convert('RGB')
-
- if self.transforms is not None:
- image = self.transforms(image)
-
- return {
- 'image': image,
- 'id': image_id,
- 'source': self.huggingface_dataset_path
- }
-
- def __len__(self) -> int:
- """
- Get the length of the dataset.
-
- Returns:
- int: Length of the dataset.
- """
- return len(self.dataset)
-
diff --git a/base_miner/datasets/real_fake_dataset.py b/base_miner/datasets/real_fake_dataset.py
deleted file mode 100644
index 0fb320a8..00000000
--- a/base_miner/datasets/real_fake_dataset.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import numpy as np
-from torchvision import transforms as T
-import torch
-
-class RealFakeDataset:
-
- def __init__(
- self,
- real_image_datasets: list,
- fake_image_datasets: list,
- fake_prob=0.5,
- source_label_mapping=None
- ):
- """
- Initialize the RealFakeDataset instance.
-
- Args:
- real_image_datasets (list): List of ImageDataset objects containing real images
-            fake_image_datasets (list): List of ImageDataset objects containing fake images
- fake_prob (float): Probability of selecting a fake image (default: 0.5).
- source_label_mapping (dict): A dictionary mapping dataset names to float labels.
- """
- self.real_image_datasets = real_image_datasets
- self.fake_image_datasets = fake_image_datasets
- self.fake_prob = fake_prob
- self.source_label_mapping = source_label_mapping
-
- self._history = {
- 'source': [],
- 'index': [],
- 'label': [],
- }
-
- def __getitem__(self, index: int) -> tuple:
- """
- Retrieve an item (image, label) from the dataset.
-        By default, 50/50 chance of real or fake. This can be overridden by self.fake_prob.
-
- Args:
- index (int): Index of the item to retrieve.
-
- Returns:
- tuple: Tuple containing the image, its label (1 : fake, 0 : real),
- and its source label (0 for real datasets and >= 1 for fake datasets).
- """
- if len(self._history['index']) > index:
- self.reset()
-
-        # sample a fake image with probability self.fake_prob, otherwise a real one
-        if np.random.rand() < self.fake_prob:
- source = self.fake_image_datasets[np.random.randint(0, len(self.fake_image_datasets))]
- image = source[index]['image']
- label = 1.0
- else:
- source = self.real_image_datasets[np.random.randint(0, len(self.real_image_datasets))]
- #imgs, idx = source.sample(1)
- image = source[index]['image']
- #image = imgs[0]['image']
- #index = idx[0]
- label = 0.0
-
- self._history['source'].append(source.huggingface_dataset_path)
- self._history['label'].append(label)
- self._history['index'].append(index)
-
- if self.source_label_mapping:
- source_label = self.source_label_mapping[source.huggingface_dataset_path]
- return image, label, source_label
-
- return image, label
-
- def __len__(self) -> int:
- """
- Return the length of the dataset.
-
- Returns:
- int: Length of the dataset (minimum length between fake and real datasets, which limits the number of
- images sampled each epoch to the length of the smallest dataset to avoid imbalance).
- """
- real_dataset_min = min([len(ds) for ds in self.real_image_datasets])
- fake_dataset_min = min([len(ds) for ds in self.fake_image_datasets])
- return min(fake_dataset_min, real_dataset_min)
-
- def reset(self):
- self._history = {
- 'source': [],
- 'index': [],
- 'label': [],
- }
-
- @staticmethod
- def collate_fn(batch):
- images, labels, source_labels = zip(*batch)
-
- images = torch.stack(images, dim=0) # Stack image tensors into a single tensor
- labels = torch.LongTensor(labels)
- source_labels = torch.LongTensor(source_labels)
-
- data_dict = {
- 'image': images,
- 'label': labels,
- 'label_spe': source_labels,
- 'landmark': None,
- 'mask': None
- }
- return data_dict
\ No newline at end of file
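A brief sketch of how this dataset is typically wired up; the `real_train`/`fake_train` lists and the `mapping` dict are placeholders, and the underlying ImageDatasets must use transforms that produce tensors for `collate_fn` to stack them:

```python
from torch.utils.data import DataLoader

train_ds = RealFakeDataset(
    real_image_datasets=real_train,     # list of ImageDataset objects (real images)
    fake_image_datasets=fake_train,     # list of ImageDataset objects (fake images)
    source_label_mapping=mapping,       # e.g. built by create_source_label_mapping()
)
loader = DataLoader(train_ds, batch_size=32, shuffle=True,
                    collate_fn=RealFakeDataset.collate_fn)
batch = next(iter(loader))              # dict with 'image', 'label', 'label_spe' tensors
```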
diff --git a/base_miner/datasets/util.py b/base_miner/datasets/util.py
deleted file mode 100644
index 7ce46550..00000000
--- a/base_miner/datasets/util.py
+++ /dev/null
@@ -1,174 +0,0 @@
-from typing import List, Tuple, Dict
-import torchvision.transforms as transforms
-import numpy as np
-import datasets
-
-from base_miner.datasets.download_data import load_huggingface_dataset
-from base_miner.datasets import ImageDataset, VideoDataset, RealFakeDataset
-
-datasets.logging.set_verbosity_error()
-datasets.disable_progress_bar()
-
-
-def split_dataset(dataset):
- # Split data into train, validation, test and return the three splits
- dataset = dataset.shuffle(seed=42)
-
- if 'train' in dataset:
- dataset = dataset['train']
-
- split_dataset = {}
- train_test_split = dataset.train_test_split(test_size=0.2, seed=42)
- split_dataset['train'] = train_test_split['train']
- temp_dataset = train_test_split['test']
-
- # Split the temporary dataset into validation and test
- val_test_split = temp_dataset.train_test_split(test_size=0.5, seed=42)
- split_dataset['validation'] = val_test_split['train']
- split_dataset['test'] = val_test_split['test']
-
- return split_dataset['train'], split_dataset['validation'], split_dataset['test']
-
-
-def load_and_split_datasets(
- dataset_meta: list,
- modality: str,
- split_transforms: Dict[str, transforms.Compose] = {},
-) -> Dict[str, List[ImageDataset]]:
- """
- Helper function to load and split dataset into train, validation, and test sets.
-
-    Args:
-        dataset_meta: List containing metadata about the datasets to load.
-        modality: Either 'image' or 'video'; selects ImageDataset or VideoDataset wrappers.
-        split_transforms: Optional dict mapping a split name to a torchvision transform pipeline.
-
-    Returns:
-        Dict[str, List[ImageDataset]]: A dictionary keyed by "train", "validation", and "test",
-        each mapping to a list of dataset wrappers.
-
-        e.g. given two dataset paths in dataset_meta,
-        {'train': [<ImageDataset>, <ImageDataset>],
-         'validation': [<ImageDataset>, <ImageDataset>],
-         'test': [<ImageDataset>, <ImageDataset>]}
- """
- splits = ['train', 'validation', 'test']
- datasets = {split: [] for split in splits}
-
- for meta in dataset_meta:
- dataset = load_huggingface_dataset(meta['path'], None, meta.get('name'))
- train_ds, val_ds, test_ds = split_dataset(dataset)
-
- for split, data in zip(splits, [train_ds, val_ds, test_ds]):
- if modality == 'image':
- image_dataset = ImageDataset(huggingface_dataset=data, transforms=split_transforms.get(split, None))
- elif modality == 'video':
- image_dataset = VideoDataset(huggingface_dataset=data, transforms=split_transforms.get(split, None))
- else:
- raise NotImplementedError(f'Unsupported modality: {modality}')
- datasets[split].append(image_dataset)
-
- split_lengths = ', '.join([f"{split} len={len(datasets[split][0])}" for split in splits])
- print(f'done, {split_lengths}')
-
- return datasets
-
-
-def create_source_label_mapping(
- real_datasets: Dict[str, List[ImageDataset]],
- fake_datasets: Dict[str, List[ImageDataset]],
- group_by_name: bool = False
- ) -> Dict:
-
- source_label_mapping = {}
- grouped_source_labels = {}
- # Iterate through real datasets and set their source label to 0.0
- for split, dataset_list in real_datasets.items():
-
- for dataset in dataset_list:
- source = dataset.huggingface_dataset_path
- if source not in source_label_mapping.keys():
- source_label_mapping[source] = 0.0
-
- # Assign incremental labels to fake datasets
- for split, dataset_list in fake_datasets.items():
- for dataset in dataset_list:
- source = dataset.huggingface_dataset_path
- if group_by_name and '__' in source:
- model_name = source.split('__')[1]
- if model_name in grouped_source_labels:
- fake_source_label = grouped_source_labels[model_name]
- else:
- fake_source_label = max(source_label_mapping.values()) + 1
- grouped_source_labels[model_name] = fake_source_label
-
- if source not in source_label_mapping:
- source_label_mapping[source] = fake_source_label
- else:
- if source not in source_label_mapping:
- source_label_mapping[source] = max(source_label_mapping.values()) + 1
-
- return source_label_mapping
-
-
-def create_real_fake_datasets(
- real_datasets: Dict[str, List[ImageDataset]],
- fake_datasets: Dict[str, List[ImageDataset]],
- source_labels: bool = False,
- group_sources_by_name: bool = False) -> Tuple[RealFakeDataset, ...]:
- """
- Args:
- real_datasets: Dict containing train, val, and test keys. Each key maps to a list of ImageDatasets
- fake_datasets: Dict containing train, val, and test keys. Each key maps to a list of ImageDatasets
-        source_labels: If True, also build and return a mapping from dataset path to source label.
-        group_sources_by_name: If True, fake datasets generated by the same model share a source label.
-    Returns:
-        Train, val, and test RealFakeDatasets (plus the source label mapping when source_labels is True)
-
- """
- source_label_mapping = None
- if source_labels:
- source_label_mapping = create_source_label_mapping(
- real_datasets, fake_datasets, group_sources_by_name)
-
- print(f"Source label mapping: {source_label_mapping}")
-
- train_dataset = RealFakeDataset(
- real_image_datasets=real_datasets['train'],
- fake_image_datasets=fake_datasets['train'],
- source_label_mapping=source_label_mapping)
-
- val_dataset = RealFakeDataset(
- real_image_datasets=real_datasets['validation'],
- fake_image_datasets=fake_datasets['validation'],
- source_label_mapping=source_label_mapping)
-
- test_dataset = RealFakeDataset(
- real_image_datasets=real_datasets['test'],
- fake_image_datasets=fake_datasets['test'],
- source_label_mapping=source_label_mapping)
-
- if source_labels:
- return train_dataset, val_dataset, test_dataset, source_label_mapping
- return train_dataset, val_dataset, test_dataset
-
-
-def sample_dataset_index_name(image_datasets: list) -> tuple[int, str]:
- """
- Randomly selects a dataset index from the provided dataset list and returns the index and source name.
-
- Parameters
- ----------
- image_datasets : list
- A list of dataset objects to select from.
-
- Returns
- -------
- tuple[int, str]
- A tuple containing the index of the randomly selected dataset and the source name.
- """
- dataset_index = np.random.randint(0, len(image_datasets))
- source_name = image_datasets[dataset_index].huggingface_dataset_path
- return dataset_index, source_name
diff --git a/base_miner/datasets/video_dataset.py b/base_miner/datasets/video_dataset.py
deleted file mode 100644
index 814e2bbc..00000000
--- a/base_miner/datasets/video_dataset.py
+++ /dev/null
@@ -1,116 +0,0 @@
-"""
-Author: Zhiyuan Yan
-Email: zhiyuanyan@link.cuhk.edu.cn
-Date: 2023-03-30
-Description: Abstract Base Class for all types of deepfake datasets.
-"""
-
-import os
-import cv2
-from PIL import Image
-import sys
-import yaml
-import numpy as np
-from copy import deepcopy
-import random
-import torch
-from torch import nn
-from torch.utils import data
-from torchvision.utils import save_image
-from torchvision.transforms import Compose
-from einops import rearrange
-from typing import List, Tuple, Optional
-from datasets import Dataset
-
-from .base_dataset import BaseDataset
-
-
-class VideoDataset(BaseDataset):
- def __init__(
- self,
- huggingface_dataset_path: Optional[str] = None,
- huggingface_dataset_split: str = 'train',
- huggingface_dataset_name: Optional[str] = None,
- huggingface_dataset: Optional[Dataset] = None,
- download_mode: Optional[str] = None,
- max_frames_per_video: Optional[int] = 4,
- transforms: Optional[Compose] = None
- ):
- """Initialize the ImageDataset.
-
- Args:
- huggingface_dataset_path (str, optional): Path to the Hugging Face dataset.
-                Can be a publicly hosted dataset (<organization>/<dataset-name>) or a
-                local directory (imagefolder:<path/to/directory>)
- huggingface_dataset_split (str): Dataset split to load. Defaults to 'train'.
- huggingface_dataset_name (str, optional): Name of the specific Hugging Face dataset subset.
- huggingface_dataset (Dataset, optional): Pre-loaded Hugging Face dataset instance.
- download_mode (str, optional): Download mode for the dataset.
- Can be None or "force_redownload"
- """
- super().__init__(
- huggingface_dataset_path=huggingface_dataset_path,
- huggingface_dataset_split=huggingface_dataset_split,
- huggingface_dataset_name=huggingface_dataset_name,
- huggingface_dataset=huggingface_dataset,
- download_mode=download_mode,
- transforms=transforms,
- )
- self.max_frames = max_frames_per_video
-
- def __getitem__(self, index):
- """Return the data point at the given index.
-
- Args:
- index (int): The index of the data point.
-
-        Returns:
-            dict: Dictionary containing the stacked frame tensor ('image', shape
-                [frame_dim, C, H, W]), the video id ('id'), and the source dataset path ('source').
- """
- image_paths = self.dataset[index]['frames']
-
- if not isinstance(image_paths, list):
- image_paths = [image_paths]
-
- images = []
- for image_path in image_paths[:self.max_frames]:
- try:
- img = Image.open(image_path)
- images.append(img)
- except Exception as e:
- print(f"Error loading image at index {index}: {e}")
- return self.__getitem__(0)
-
- if self.transforms is not None:
- images = self.transforms(images)
-
- # Stack images along the time dimension (frame_dim)
- image_tensors = torch.stack(images, dim=0) # Shape: [frame_dim, C, H, W]
-
- frames, channels, height, width = image_tensors.shape
- x = torch.randint(0, width, (1,)).item()
- y = torch.randint(0, height, (1,)).item()
- mask_grid_size = 16
- x1 = max(x - mask_grid_size // 2, 0)
- x2 = min(x + mask_grid_size // 2, width)
- y1 = max(y - mask_grid_size // 2, 0)
- y2 = min(y + mask_grid_size // 2, height)
- image_tensors[:, :, y1:y2, x1:x2] = -1
-
- return {
- 'image': image_tensors, # Shape: [frame_dim, C, H, W]
- 'id': self.dataset[index]['video_id'],
- 'source': self.huggingface_dataset_path
- }
-
-
- def __len__(self) -> int:
- """
- Get the length of the dataset.
-
- Returns:
- int: Length of the dataset.
- """
- return len(self.dataset['video_id'])
\ No newline at end of file
diff --git a/base_miner/deepfake_detectors/__init__.py b/base_miner/deepfake_detectors/__init__.py
deleted file mode 100644
index 37529bfa..00000000
--- a/base_miner/deepfake_detectors/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .deepfake_detector import DeepfakeDetector
-from .npr_detector import NPRImageDetector
-from .ucf_detector import UCFImageDetector
-from .camo_detector import CAMOImageDetector
-from .tall_detector import TALLVideoDetector
diff --git a/base_miner/deepfake_detectors/camo_detector.py b/base_miner/deepfake_detectors/camo_detector.py
deleted file mode 100644
index 8408ce3a..00000000
--- a/base_miner/deepfake_detectors/camo_detector.py
+++ /dev/null
@@ -1,80 +0,0 @@
-from pathlib import Path
-import yaml
-import torch
-from PIL import Image
-from base_miner.registry import DETECTOR_REGISTRY
-from base_miner.gating_mechanisms import GatingMechanism
-from base_miner.deepfake_detectors import DeepfakeDetector
-
-
-@DETECTOR_REGISTRY.register_module(module_name='CAMO')
-class CAMOImageDetector(DeepfakeDetector):
- """
- This DeepfakeDetector subclass implements Content-Aware Model Orchestration
- (CAMO), a mixture-of-experts approach to the binary classification of
- real and fake images, breaking the classification problem into content-specific
- subproblems.
-
- The subproblems are solved by using a GatingMechanism to route image
- content to appropriate DeepfakeDetector subclass instance(s) that
- initialize models pretrained to handle the content type.
-
- Attributes:
- model_name (str): Name of the detector instance.
- config_name (str): Name of the YAML file in deepfake_detectors/config/ to load
- attributes from.
- device (str): The type of device ('cpu' or 'cuda').
- """
-
- def __init__(self, model_name: str = 'CAMO', config_name: str = 'camo.yaml', device: str = 'cpu'):
- """
-        Initialize the CAMOImageDetector with dynamic model selection based on config.
- """
- self.detectors = {}
- super().__init__(model_name, config_name, device)
-
- gate_names = [
- content_type for content_type in self.content_type
- if self.content_type[content_type].get('use_gate', False)
- ]
- self.gating_mechanism = GatingMechanism(gate_names)
-
- def load_model(self):
- """
- Load detectors dynamically based on the provided configuration and registry.
- """
- for content_type, detector_info in self.content_type.items():
- model_name = detector_info['model']
- detector_config = detector_info['detector_config']
-
- if model_name in DETECTOR_REGISTRY:
- self.detectors[content_type] = DETECTOR_REGISTRY[model_name](
- model_name=f'{model_name}_{content_type.capitalize()}',
- config_name=detector_config,
- device=self.device
- )
- else:
- raise ValueError(f"Detector {model_name} not found in the registry for {content_type}.")
-
- def __call__(
- self, image: Image
- ) -> float:
- """
- Perform inference using the CAMO detector.
-
- Args:
- image (PIL.Image): The input image to classify.
-
- Returns:
- float: The prediction score indicating the likelihood of the image being a deepfake.
- """
- gate_results = self.gating_mechanism(image)
- expert_outputs = {}
- for content_type, gate_output_image in gate_results.items():
- pred = self.detectors[content_type](gate_output_image)
- expert_outputs[content_type] = pred
-
- if len(expert_outputs) == 0:
- return self.detectors['general'](image)
-
- return max(expert_outputs.values())
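A brief, hypothetical usage sketch of the detector above; it assumes the YAML configs and pretrained weights referenced by `configs/camo.yaml` can be resolved (downloaded from Hugging Face on first use), and `sample.jpg` is a placeholder path:

```python
from PIL import Image
from base_miner.deepfake_detectors import CAMOImageDetector

detector = CAMOImageDetector(device='cpu')     # loads the 'general' and 'face' UCF experts
image = Image.open('sample.jpg').convert('RGB')
prob_fake = detector(image)                    # routed through the gate(s); max expert score
print(f'P(fake) = {float(prob_fake):.3f}')
```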
diff --git a/base_miner/deepfake_detectors/configs/camo.yaml b/base_miner/deepfake_detectors/configs/camo.yaml
deleted file mode 100644
index 762b2bb1..00000000
--- a/base_miner/deepfake_detectors/configs/camo.yaml
+++ /dev/null
@@ -1,10 +0,0 @@
-# CAMO Configuration
-
-content_type:
- general:
- model: 'UCF'
- detector_config: 'ucf.yaml' # Default model for 'general'
- face:
- model: 'UCF'
- detector_config: 'ucf_face.yaml' # Default model for 'face'
- use_gate: True
\ No newline at end of file
diff --git a/base_miner/deepfake_detectors/configs/npr.yaml b/base_miner/deepfake_detectors/configs/npr.yaml
deleted file mode 100644
index ec525946..00000000
--- a/base_miner/deepfake_detectors/configs/npr.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-# NPR Generalist Configuration
-hf_repo: 'bitmind/npr' # Hugging Face repository for downloading model files
-weights: 'npr.pth'
\ No newline at end of file
diff --git a/base_miner/deepfake_detectors/configs/tall.yaml b/base_miner/deepfake_detectors/configs/tall.yaml
deleted file mode 100644
index d1248a24..00000000
--- a/base_miner/deepfake_detectors/configs/tall.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-hf_repo: 'bitmind/tall' # Hugging Face repository for downloading model files
-config_name: 'tall.yaml' # pre-trained configuration file in HuggingFace
-weights: 'tall_trainFF_testCDF.pth' # TALL model checkpoint in HuggingFace
\ No newline at end of file
diff --git a/base_miner/deepfake_detectors/configs/ucf.yaml b/base_miner/deepfake_detectors/configs/ucf.yaml
deleted file mode 100644
index db9978c8..00000000
--- a/base_miner/deepfake_detectors/configs/ucf.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-# UCFDetector Generalist Configuration
-hf_repo: 'bitmind/bm-ucf' # Hugging Face repository for downloading model files
-config_name: 'bm-general-config-v1.yaml' # pre-trained configuration file in HuggingFace
-weights: 'bm-general-v1.pth' # UCF model checkpoint in HuggingFace
\ No newline at end of file
diff --git a/base_miner/deepfake_detectors/configs/ucf_face.yaml b/base_miner/deepfake_detectors/configs/ucf_face.yaml
deleted file mode 100644
index 4cd4c5b6..00000000
--- a/base_miner/deepfake_detectors/configs/ucf_face.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-# UCFDetector Face Expert Configuration
-hf_repo: 'bitmind/bm-ucf' # Hugging Face repository for downloading model files
-config_name: 'bm-faces-config-v1.yaml' # pre-trained configuration file in HuggingFace
-weights: 'bm-faces-v1.pth' # UCF model checkpoint in HuggingFace
diff --git a/base_miner/deepfake_detectors/deepfake_detector.py b/base_miner/deepfake_detectors/deepfake_detector.py
deleted file mode 100644
index 16bf0cd7..00000000
--- a/base_miner/deepfake_detectors/deepfake_detector.py
+++ /dev/null
@@ -1,153 +0,0 @@
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import Optional, Dict, Any
-
-import torch
-import yaml
-import bittensor as bt
-from PIL import Image
-from huggingface_hub import hf_hub_download
-
-from base_miner.DFB.config.constants import CONFIGS_DIR, WEIGHTS_DIR
-
-
-class DeepfakeDetector(ABC):
- """Abstract base class for detecting deepfake images via binary classification.
-
- This class is intended to be subclassed by detector implementations
- using different underlying model architectures, routing via gates, or
- configurations.
-
- Attributes:
- model_name (str): Name of the detector instance.
- config_name (Optional[str]): Name of the YAML file in deepfake_detectors/config/
- to load instance attributes from.
- device (str): The type of device ('cpu' or 'cuda').
- hf_repo (str): Hugging Face repository name for model weights.
- """
-
- def __init__(
- self,
- model_name: str,
- config_name: Optional[str] = None,
- device: str = 'cpu'
- ) -> None:
- """Initialize the DeepfakeDetector.
-
- Args:
- model_name: Name of the detector instance.
-            config_name: Optional name of configuration file to load.
- device: Device to run the model on ('cpu' or 'cuda').
- """
- self.model_name = model_name
- self.device = torch.device(
- device if device == 'cuda' and torch.cuda.is_available() else 'cpu'
- )
-
- if config_name:
- print(f"Configuring with {config_name}")
- self.set_class_attrs(config_name)
- self.load_model_config()
-
- self.load_model()
-
- @abstractmethod
- def load_model(self) -> None:
- """Load the model weights and architecture.
-
- This method should be implemented by subclasses to define their specific
- model loading logic.
- """
- pass
-
- def preprocess(self, image: Image.Image) -> torch.Tensor:
- """Preprocess the image for model inference.
-
- Args:
- image: The input image to preprocess.
-
- Returns:
- The preprocessed image as a tensor ready for model input.
- """
- # General preprocessing, to be overridden if necessary in subclasses
- pass
-
- @abstractmethod
- def __call__(self, image: Image.Image) -> float:
- """Perform inference with the model.
-
- Args:
- image: The preprocessed input image.
-
- Returns:
- The model's prediction score (typically between 0 and 1).
- """
- pass
-
- def set_class_attrs(self, detector_config: str) -> None:
- """Load detector configuration from YAML file and set attributes.
-
- Args:
- detector_config: Path to the YAML configuration file or filename
- in the configs directory.
-
- Raises:
- Exception: If there is an error loading or parsing the config file.
- """
- if Path(detector_config).exists():
- detector_config_file = Path(detector_config)
- else:
- detector_config_file = (
- Path(__file__).resolve().parent / Path('configs/' + detector_config)
- )
-
- try:
- with open(detector_config_file, 'r', encoding='utf-8') as file:
- config_dict = yaml.safe_load(file)
-
- # Set class attributes dynamically from the config dictionary
- for key, value in config_dict.items():
- setattr(self, key, value)
-
- except Exception as e:
- print(f"Error loading detector configurations from {detector_config_file}: {e}")
- raise
-
- def ensure_weights_are_available(
- self,
- weights_dir: str,
- weights_filename: str
- ) -> None:
- """Ensure model weights are downloaded and available locally.
-
- Downloads weights from Hugging Face Hub if not found locally.
-
- Args:
- weights_dir: Directory to store/find the weights.
- weights_filename: Name of the weights file.
- """
- destination_path = Path(weights_dir) / Path(weights_filename)
- if not Path(weights_dir).exists():
- Path(weights_dir).mkdir(parents=True, exist_ok=True)
-
- if not destination_path.exists():
- print(f"Downloading {weights_filename} from {self.hf_repo} "
- f"to {weights_dir}")
- hf_hub_download(self.hf_repo, weights_filename, local_dir=weights_dir)
-
- def load_model_config(self):
- try:
- destination_path = Path(CONFIGS_DIR) / Path(self.config_name)
- if not destination_path.exists():
- local_config_path = hf_hub_download(self.hf_repo, self.config_name, local_dir=CONFIGS_DIR)
- print(f"Downloaded {self.hf_repo}/{self.config_name} to {local_config_path}")
- with Path(local_config_path).open('r') as f:
- self.config = yaml.safe_load(f)
- else:
- print(f"Loading local config from {destination_path}")
- with destination_path.open('r') as f:
- self.config = yaml.safe_load(f)
- print(f"Loaded: {self.config}")
- except Exception as e:
- # some models such as NPR don't have an additional config file
- bt.logging.warning("No additional train config loaded.")
diff --git a/base_miner/deepfake_detectors/npr_detector.py b/base_miner/deepfake_detectors/npr_detector.py
deleted file mode 100644
index a5fe3500..00000000
--- a/base_miner/deepfake_detectors/npr_detector.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import torch
-import numpy as np
-from PIL import Image
-from pathlib import Path
-from huggingface_hub import hf_hub_download
-from base_miner.NPR.networks.resnet import resnet50
-from bitmind.utils.image_transforms import get_base_transforms
-from base_miner.deepfake_detectors import DeepfakeDetector
-from base_miner.registry import DETECTOR_REGISTRY
-from base_miner.NPR.config.constants import WEIGHTS_DIR
-
-
-@DETECTOR_REGISTRY.register_module(module_name='NPR')
-class NPRImageDetector(DeepfakeDetector):
- """
- DeepfakeDetector subclass that initializes a pretrained NPR model
- for binary classification of fake and real images.
-
- Attributes:
- model_name (str): Name of the detector instance.
- config_name (str): Name of the YAML file in deepfake_detectors/config/ to load
- attributes from.
- device (str): The type of device ('cpu' or 'cuda').
- """
-
- def __init__(self, model_name: str = 'NPR', config_name: str = 'npr.yaml', device: str = 'cpu'):
- super().__init__(model_name, config_name, device)
- self.transforms = get_base_transforms()
-
- def load_model(self):
- """
- Load the ResNet50 model with the specified weights for deepfake detection.
- """
- self.ensure_weights_are_available(WEIGHTS_DIR, self.weights)
- self.model = resnet50(num_classes=1)
- self.model.load_state_dict(torch.load(Path(WEIGHTS_DIR) / self.weights, map_location=self.device))
- self.model.eval()
-
- def preprocess(self, image: Image) -> torch.Tensor:
- """
- Preprocess the image using the base_transforms function.
-
- Args:
- image (PIL.Image): The image to preprocess.
-
- Returns:
- torch.Tensor: The preprocessed image tensor.
- """
- image_tensor = self.transforms(image).unsqueeze(0).float()
- return image_tensor
-
- def __call__(self, image: Image) -> float:
- """
- Perform inference with the model.
-
- Args:
- image (PIL.Image): The image to process.
-
- Returns:
- float: The prediction score indicating the likelihood of the image being a deepfake.
- """
- image_tensor = self.preprocess(image)
- with torch.no_grad():
- out = np.asarray(self.model(image_tensor).sigmoid().flatten())
-        return float(out[0])
diff --git a/base_miner/deepfake_detectors/tall_detector.py b/base_miner/deepfake_detectors/tall_detector.py
deleted file mode 100644
index 7b4bd40a..00000000
--- a/base_miner/deepfake_detectors/tall_detector.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import torch
-from pathlib import Path
-
-import bittensor as bt
-from base_miner.registry import DETECTOR_REGISTRY
-from base_miner.DFB.config.constants import CONFIGS_DIR, WEIGHTS_DIR
-from base_miner.DFB.detectors import DETECTOR, TALLDetector
-from base_miner.deepfake_detectors import DeepfakeDetector
-from bitmind.utils.video_utils import pad_frames
-
-
-@DETECTOR_REGISTRY.register_module(module_name="TALL")
-class TALLVideoDetector(DeepfakeDetector):
- def __init__(
- self,
- model_name: str = "TALL",
- config_name: str = "tall.yaml",
- device: str = "cpu",
- ):
- super().__init__(model_name, config_name, device)
-
- total_params = sum(p.numel() for p in self.tall.model.parameters())
- trainable_params = sum(
- p.numel() for p in self.tall.model.parameters() if p.requires_grad
- )
- bt.logging.info('device:', self.device)
- bt.logging.info(total_params, "parameters")
- bt.logging.info(trainable_params, "trainable parameters")
-
- def load_model(self):
- # download weights from hf if not available locally
- self.ensure_weights_are_available(WEIGHTS_DIR, self.weights)
- bt.logging.info(f"Loaded config: {self.config}")
- self.tall = TALLDetector(self.config, self.device)
-
- # load weights
- checkpoint_path = Path(WEIGHTS_DIR) / self.weights
- checkpoint = torch.load(checkpoint_path, map_location=self.device)
- self.tall.load_state_dict(checkpoint, strict=True)
- self.tall.model.eval()
-
- def preprocess(self, frames_tensor):
- """ Prepare input data dict for TALLDetector """
- frames_tensor = pad_frames(frames_tensor, 4)
- return {'image': frames_tensor}
-
- def __call__(self, frames_tensor):
- input_data = self.preprocess(frames_tensor)
- with torch.no_grad():
- output_data = self.tall.forward(input_data, inference=True)
- return output_data['prob'][0]
diff --git a/base_miner/deepfake_detectors/ucf_detector.py b/base_miner/deepfake_detectors/ucf_detector.py
deleted file mode 100644
index 92b40b54..00000000
--- a/base_miner/deepfake_detectors/ucf_detector.py
+++ /dev/null
@@ -1,136 +0,0 @@
-import os
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Ignore INFO and WARN messages
-
-import random
-import warnings
-warnings.filterwarnings("ignore", category=FutureWarning)
-from huggingface_hub import hf_hub_download
-from pathlib import Path
-from PIL import Image
-import torchvision.transforms as transforms
-import torch.backends.cudnn as cudnn
-import bittensor as bt
-import numpy as np
-import torch
-import yaml
-import gc
-
-from base_miner.DFB.config.constants import CONFIGS_DIR, WEIGHTS_DIR
-from base_miner.deepfake_detectors import DeepfakeDetector
-from base_miner.DFB.detectors import UCFDetector
-from base_miner.registry import DETECTOR_REGISTRY
-
-
-@DETECTOR_REGISTRY.register_module(module_name='UCF')
-class UCFImageDetector(DeepfakeDetector):
- """
- DeepfakeDetector subclass that initializes a pretrained UCF model
- for binary classification of fake and real images.
-
- Attributes:
- model_name (str): Name of the detector instance.
- config_name (str): Name of the YAML file in deepfake_detectors/config/ to load
- attributes from.
- device (str): The type of device ('cpu' or 'cuda').
- """
-
- def __init__(self, model_name: str = 'UCF', config_name: str = 'ucf.yaml', device: str = 'cpu'):
- super().__init__(model_name, config_name, device)
-
- def init_cudnn(self):
- if self.config.get('cudnn'):
- cudnn.benchmark = True
-
- def init_seed(self):
- seed_value = self.config.get('manualSeed')
- if seed_value:
- random.seed(seed_value)
- torch.manual_seed(seed_value)
- torch.cuda.manual_seed_all(seed_value)
-
- def load_model(self):
- self.init_cudnn()
- self.init_seed()
- self.ensure_weights_are_available(WEIGHTS_DIR, self.weights)
- pretrained = self.config['pretrained']
- if isinstance(pretrained, dict) and 'filename' in pretrained:
- pretrained = pretrained['filename']
- else:
- pretrained = pretrained.split('/')[-1]
-
- self.ensure_weights_are_available(WEIGHTS_DIR, pretrained)
- self.model = UCFDetector(self.config).to(self.device)
- self.model.eval()
- weights_path = Path(WEIGHTS_DIR) / self.weights
- bt.logging.info(f"Loading checkpoint {weights_path}")
- checkpoint = torch.load(weights_path, map_location=self.device)
- try:
- self.model.load_state_dict(checkpoint, strict=True)
- except RuntimeError as e:
- if 'size mismatch' in str(e):
- # Create a custom error message
- custom_message = (
- "\n\n Error: Incorrect specific_task_num in model config. The 'specific_task_num' "
- "in 'config_path' yaml should match the value used during training. "
- "A mismatch results in an incorrect output layer shape for UCF's learned disentanglement"
- " of different forgery methods/sources.\n\n"
-                    "Solution: Use the same config.yaml to initialize UCFDetector ('config_path' arg) "
- "as output during training (config.yaml saved alongside weights in the training run's "
- "logs directory). Or simply modify your config.yaml to ensure 'specific_task_num' equals "
- "the value set during training (defaults to num fake training datasets + 1).\n"
- )
- raise RuntimeError(custom_message) from e
- else: raise e
-
- def preprocess(self, image, res=256):
- """Preprocess the image for model inference.
-
- Returns:
- torch.Tensor: The preprocessed image tensor, ready for model inference.
- """
- # Convert image to RGB format to ensure consistent color handling.
- image = image.convert('RGB')
-
- # Define transformation sequence for image preprocessing.
- transform = transforms.Compose([
- transforms.Resize((res, res), interpolation=Image.LANCZOS), # Resize image to specified resolution.
- transforms.ToTensor(), # Convert the image to a PyTorch tensor.
- transforms.Normalize(mean=self.config['mean'], std=self.config['std']) # Normalize the image tensor.
- ])
-
- # Apply transformations and add a batch dimension for model inference.
- image_tensor = transform(image).unsqueeze(0)
-
- # Move the image tensor to the specified device (e.g., GPU).
- return image_tensor.to(self.device)
-
- def infer(self, image_tensor):
- """ Perform inference using the model. """
- with torch.no_grad():
- self.model({'image': image_tensor}, inference=True)
- return self.model.prob[-1]
-
- def __call__(self, image: Image) -> float:
- image_tensor = self.preprocess(image)
- return self.infer(image_tensor)
-
- def free_memory(self):
- """ Frees up memory by setting model and large data structures to None. """
- if self.model is not None:
- self.model.cpu() # Move model to CPU to free up GPU memory (if applicable)
- del self.model
- self.model = None
-
- if self.face_detector is not None:
- del self.face_detector
- self.face_detector = None
-
- if self.face_predictor is not None:
- del self.face_predictor
- self.face_predictor = None
-
- gc.collect()
-
- # If using GPUs and PyTorch, clear the cache as well
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
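The free_memory method removed above follows a common pattern for releasing model memory in PyTorch: drop references, run the garbage collector, then clear the CUDA cache. A minimal sketch of that pattern, assuming torch is installed; the holder class below is illustrative and not part of the project.

import gc
import torch

class ModelHolder:
    def __init__(self, model=None):
        self.model = model

    def free_memory(self):
        """Release the model and reclaim GPU memory where applicable."""
        if self.model is not None:
            if hasattr(self.model, "cpu"):
                self.model.cpu()  # move parameters off the GPU first
            del self.model
            self.model = None
        gc.collect()  # reclaim Python-level references
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # release cached CUDA allocations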
diff --git a/base_miner/deepfake_detectors/unit_tests/__init__.py b/base_miner/deepfake_detectors/unit_tests/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/base_miner/deepfake_detectors/unit_tests/base_npr_weights.pth b/base_miner/deepfake_detectors/unit_tests/base_npr_weights.pth
deleted file mode 100644
index 646d7e15..00000000
Binary files a/base_miner/deepfake_detectors/unit_tests/base_npr_weights.pth and /dev/null differ
diff --git a/base_miner/deepfake_detectors/unit_tests/run_all_unit_tests.py b/base_miner/deepfake_detectors/unit_tests/run_all_unit_tests.py
deleted file mode 100644
index 1e79a610..00000000
--- a/base_miner/deepfake_detectors/unit_tests/run_all_unit_tests.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import os
-import subprocess
-
-def run_all_py_scripts(directory):
- # List all files in the directory
- for filename in os.listdir(directory):
- # Check if the file ends with .py and is not this script itself
- if filename.endswith('.py') and filename != os.path.basename(__file__):
- # Full path of the python file
- filepath = os.path.join(directory, filename)
- print(f"Running {filename}...")
-
- # Run the script using subprocess
- subprocess.run(['python', filepath])
-
-if __name__ == "__main__":
- # Run all python files in the current directory
- run_all_py_scripts(os.getcwd())
\ No newline at end of file
diff --git a/base_miner/deepfake_detectors/unit_tests/sample_image.jpg b/base_miner/deepfake_detectors/unit_tests/sample_image.jpg
deleted file mode 100644
index fd2fff10..00000000
Binary files a/base_miner/deepfake_detectors/unit_tests/sample_image.jpg and /dev/null differ
diff --git a/base_miner/deepfake_detectors/unit_tests/test_camo_detector.py b/base_miner/deepfake_detectors/unit_tests/test_camo_detector.py
deleted file mode 100644
index 9b0c3e61..00000000
--- a/base_miner/deepfake_detectors/unit_tests/test_camo_detector.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import unittest
-import torch
-import numpy as np
-from PIL import Image
-import os
-import sys
-#CAMODetector class located in the parent directory
-directory = os.path.dirname(os.path.abspath(__file__))
-parent_directory = os.path.dirname(directory)
-sys.path.append(parent_directory)
-from camo_detector import CAMODetector
-
-
-class TestCAMODetector(unittest.TestCase):
- def setUp(self):
- """Set up the necessary information to test CAMODetector."""
- self.script_dir = os.path.dirname(__file__)
- # Set the path of the sample image
- self.image_path = os.path.join(self.script_dir, 'sample_image.jpg')
- self.camo_detector = CAMODetector()
-
- def test_load_model(self):
- """Test if the models load properly with the given weight paths."""
- self.assertIsNotNone(self.camo_detector.detectors['face'], "Face detector should not be None")
- self.assertIsNotNone(self.camo_detector.detectors['general'], "General detector should not be None")
-
-    def test_load_gates(self):
-        """Test if the gating mechanism loads properly."""
-        self.assertIsNotNone(self.camo_detector.gating_mechanism, "GatingMechanism should not be None")
-
- def test_call(self):
- """Test the __call__ method for inference on a given image."""
- image = Image.open(self.image_path)
- prediction = self.camo_detector(image)
- print(f"Prediction: {prediction}")
- self.assertIsNotNone(prediction, "Inference output should not be None")
- self.assertIsInstance(prediction, np.ndarray, "Output should be a np.ndarray containing a float value")
- self.assertTrue(0 <= prediction <= 1, "Output should be between 0 and 1")
-
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
diff --git a/base_miner/deepfake_detectors/unit_tests/test_npr_detector.py b/base_miner/deepfake_detectors/unit_tests/test_npr_detector.py
deleted file mode 100644
index b8e94533..00000000
--- a/base_miner/deepfake_detectors/unit_tests/test_npr_detector.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import unittest
-import torch
-from PIL import Image
-import os
-import sys
-import numpy as np
-# NPRDetector class located in the parent directory
-directory = os.path.dirname(os.path.abspath(__file__))
-parent_directory = os.path.dirname(directory)
-sys.path.append(parent_directory)
-from npr_detector import NPRDetector
-
-
-class TestNPRDetector(unittest.TestCase):
- def setUp(self):
- """Set up the necessary information to test NPRDetector."""
- self.script_dir = os.path.dirname(__file__)
- # Set the path of the sample image
- self.image_path = os.path.join(self.script_dir, 'sample_image.jpg')
- self.npr_detector = NPRDetector()
-
- def test_load_model(self):
- """Test if the model loads properly with the given weight path."""
- self.assertIsNotNone(self.npr_detector.model, "Model should not be None")
-
- def test_preprocess(self):
- """Test image preprocessing."""
- image = Image.open(self.image_path)
- tensor = self.npr_detector.preprocess(image)
- print(f"Preprocessed tensor: {tensor}")
- self.assertIsInstance(tensor, torch.Tensor, "Output should be a torch.Tensor")
- self.assertEqual(tensor.dim(), 4, "Tensor should have a dimension of 4")
- self.assertEqual(tensor.shape[1], 3, "Tensor should have 3 channels")
-
- def test_inference(self):
- """Test model inference on a preprocessed image."""
- image = Image.open(self.image_path)
- prediction = self.npr_detector(image)
- print(f"Prediction: {prediction}, Type: {type(prediction)}")
- self.assertIsInstance(prediction, np.ndarray, "Output should be a np.ndarray containing a float value")
- self.assertTrue(0 <= prediction <= 1, "Output should be between 0 and 1")
-
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
diff --git a/base_miner/deepfake_detectors/unit_tests/test_registry.py b/base_miner/deepfake_detectors/unit_tests/test_registry.py
deleted file mode 100644
index 26932e24..00000000
--- a/base_miner/deepfake_detectors/unit_tests/test_registry.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import unittest
-import os
-import sys
-
-#Registry class located in the parent directory
-directory = os.path.dirname(os.path.abspath(__file__))
-parent_directory = os.path.dirname(directory)
-sys.path.append(parent_directory)
-
-# Unit test class to test DETECTOR_REGISTRY
-class TestDetectorRegistry(unittest.TestCase):
-
- def test_registry_contents(self):
- from base_miner.registry import Registry
- detector_registry = Registry()
- # Check if the registry has the expected keys (class names or custom names)
- registered_keys = list(detector_registry.data.keys())
-
- # Print all the registered models
- print("Registered detectors:")
- for name in registered_keys:
- print(f"Detector Name: {name}, Class: {detector_registry[name]}")
-
- # Assert that all expected keys are present
- self.assertEqual(len(registered_keys), 0, "There should be no registered detectors.")
-
- def test_registry_contents_after_import(self):
- from base_miner import DETECTOR_REGISTRY
- # Check if the registry has the expected keys (class names or custom names)
- registered_keys = list(DETECTOR_REGISTRY.data.keys())
-
- # Print all the registered models
- print("Registered detectors:")
- for name in registered_keys:
- print(f"Detector Name: {name}, Class: {DETECTOR_REGISTRY[name]}")
-
- # Assert that all expected keys are present
- self.assertIsNotNone(registered_keys, "Registered detectors should not be None")
-
-
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
diff --git a/base_miner/deepfake_detectors/unit_tests/test_ucf_detector.py b/base_miner/deepfake_detectors/unit_tests/test_ucf_detector.py
deleted file mode 100644
index cf87138b..00000000
--- a/base_miner/deepfake_detectors/unit_tests/test_ucf_detector.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import unittest
-import os
-import sys
-from PIL import Image
-import numpy as np
-from pathlib import Path
-from base_miner.UCF.config.constants import WEIGHTS_DIR
-#UCFDetector class located in the parent directory
-directory = os.path.dirname(os.path.abspath(__file__))
-parent_directory = os.path.dirname(directory)
-sys.path.append(parent_directory)
-from ucf_detector import UCFDetector
-
-
-class TestUCFDetector(unittest.TestCase):
- def setUp(self):
- """Set up the necessary information to test UCFDetector."""
- # Set up a test instance of the UCFDetector class
- self.ucf_detector = UCFDetector()
- self.ucf_detector_face = UCFDetector(config='ucf_face.yaml')
- # Set the path of the sample image
- self.image_path = os.path.join(os.path.dirname(__file__), 'sample_image.jpg')
-
- def test_load_config(self):
- """Test if the configuration is loaded properly."""
-        self.assertIsNotNone(self.ucf_detector.train_config, "Generalist config should not be None")
- self.assertIsNotNone(self.ucf_detector_face.train_config, "Face config should not be None")
-
- def test_ensure_weights(self):
- """Test if the weights are checked and downloaded if missing."""
- self.assertTrue((Path(WEIGHTS_DIR) / self.ucf_detector.weights).exists(),
- "Model weights should be available after initialization.")
- self.assertTrue((Path(WEIGHTS_DIR) / self.ucf_detector.train_config['pretrained'].split('/')[-1]).exists(),
- "Backbone weights should be available after initialization.")
- self.assertTrue((Path(WEIGHTS_DIR) / self.ucf_detector_face.weights).exists(),
- "Face model weights should be available after initialization.")
- self.assertTrue((Path(WEIGHTS_DIR) / self.ucf_detector_face.train_config['pretrained'].split('/')[-1]).exists(),
- "Face backbone weights should be available after initialization.")
-
- def test_model_loading(self):
- """Test if the model is loaded properly."""
- self.assertIsNotNone(self.ucf_detector.model, "Generalist model should not be None")
- self.assertIsNotNone(self.ucf_detector_face.model, "Face model should not be None")
-
- def test_infer_general(self):
- """Test a basic inference to ensure model outputs are correct."""
- image = Image.open(self.image_path)
- preprocessed_image = self.ucf_detector.preprocess(image)
- output = self.ucf_detector.infer(preprocessed_image)
- print(f"General Output: {output}")
- self.assertIsNotNone(output, "Inference output should not be None")
- self.assertIsInstance(output, np.ndarray, "Output should be a np.ndarray containing a float value")
-
- def test_infer_general_call(self):
- """Test the __call__ method to ensure inference is correct."""
- image = Image.open(self.image_path)
- output = self.ucf_detector(image)
- print(f"General __call__ method output: {output}")
- self.assertIsNotNone(output, "Inference output should not be None")
- self.assertIsInstance(output, np.ndarray, "Output should be a np.ndarray containing a float value")
-
- def test_infer_face(self):
- """Test a basic inference to ensure model outputs are correct."""
- image = Image.open(self.image_path)
- preprocessed_image = self.ucf_detector_face.preprocess(image)
- output = self.ucf_detector_face.infer(preprocessed_image)
- print(f"Face Output: {output}")
- self.assertIsNotNone(output, "Inference output should not be None")
- self.assertIsInstance(output, np.ndarray, "Output should be a np.ndarray containing a float value")
-
- def test_infer_face_call(self):
- """Test the __call__ method to ensure inference is correct."""
- image = Image.open(self.image_path)
- output = self.ucf_detector_face(image)
- print(f"Face __call__ method output: {output}")
- self.assertIsNotNone(output, "Inference output should not be None")
- self.assertIsInstance(output, np.ndarray, "Output should be a np.ndarray containing a float value")
-
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
diff --git a/base_miner/gating_mechanisms/__init__.py b/base_miner/gating_mechanisms/__init__.py
deleted file mode 100644
index dc2e79be..00000000
--- a/base_miner/gating_mechanisms/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .gate import Gate
-from .face_gate import FaceGate
-from .gating_mechanism import GatingMechanism
\ No newline at end of file
diff --git a/base_miner/gating_mechanisms/face_gate.py b/base_miner/gating_mechanisms/face_gate.py
deleted file mode 100644
index d433056d..00000000
--- a/base_miner/gating_mechanisms/face_gate.py
+++ /dev/null
@@ -1,72 +0,0 @@
-from PIL import Image
-from _dlib_pybind11 import rectangles
-import numpy as np
-import dlib
-
-from base_miner.gating_mechanisms import Gate
-from base_miner.DFB.config.constants import DLIB_FACE_PREDICTOR_PATH
-from base_miner.registry import GATE_REGISTRY
-from base_miner.gating_mechanisms.utils import get_face_landmarks, align_and_crop_face
-
-
-@GATE_REGISTRY.register_module(module_name='FACE')
-class FaceGate(Gate):
- """
- Gate subclass for face content detection and preprocessing.
-
- Attributes:
- gate_name (str): The name of the gate.
- predictor_path (str): Path to dlib face landmark model.
- """
-
- def __init__(self, gate_name: str = 'FaceGate', predictor_path=DLIB_FACE_PREDICTOR_PATH):
- self.face_detector = dlib.get_frontal_face_detector()
- self.face_predictor = dlib.shape_predictor(predictor_path)
- super().__init__(gate_name, "face")
-
- def preprocess(self, image: np.ndarray, faces: rectangles, res=256) -> any:
- """
- Align and crop the largest face in the image
-
- Args:
- image: Input image array
- faces: Output out of a dlib face detection model
- res: NxN image size
-
- Returns:
- preprocessed image with largest face aligned and cropped
- """
-
- # For now only take the biggest face
- face = max(faces, key=lambda rect: rect.width() * rect.height())
-
- # Get the landmarks/parts for the face in box d only with the five key points
- face_shape = self.face_predictor(image, face)
- landmarks = get_face_landmarks(face_shape)
- cropped_face, _ = align_and_crop_face(image, landmarks, outsize=(res, res))
- # return cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR)
-
- # Convert the cropped face back to a PIL Image if cropping was successful.
- if cropped_face is not None:
- image = Image.fromarray(cropped_face)
- else:
- print("Largest face was not successfully cropped.")
- return image
-
- def __call__(self, image: Image, res: int = 256) -> any:
- """
- Perform face detection and image aligning and cropping to the face.
-
- Args:
- image (PIL.Image): The image to classify and preprocess if content is detected.
-
- Returns:
-            (PIL.Image, bool): The processed face image (or original image if no face was found) and whether a face was detected.
- """
- image_np = np.array(image)
- faces = self.face_detector(image_np, 1)
- if faces is None or len(faces) == 0:
- return image, False
-
- processed_image = self.preprocess(image_np, faces, res)
- return processed_image, True
diff --git a/base_miner/gating_mechanisms/gate.py b/base_miner/gating_mechanisms/gate.py
deleted file mode 100644
index c73ffcc0..00000000
--- a/base_miner/gating_mechanisms/gate.py
+++ /dev/null
@@ -1,38 +0,0 @@
-from PIL import Image
-from abc import ABC, abstractmethod
-import numpy as np
-
-
-class Gate(ABC):
- """
- Abstract base class for image content detection and preprocessing.
- Used to route deepfake detection inference inputs to tailored models
- in a single agent or mixture-of-experts design.
-
- This class is intended to be subclassed by specific gate
- implementations that handle different content types.
-
- Attributes:
- gate_name (str): The name of the gate.
- content_type (str): The type of content handled by the gate.
- """
-
- def __init__(self, gate_name: str, content_type: str):
- self.gate_name = gate_name
- self.content_type = content_type
-
- @abstractmethod
- def preprocess(self, image: np.array) -> any:
- """Preprocess the image based on its content type."""
- return image
-
- @abstractmethod
- def __call__(self, image: Image) -> any:
- """
- Perform content classification and content-specific preprocessing.
- Used to route inputs to appropriate models for inference.
-
- Args:
- image (PIL.Image): The image to preprocess.
- """
- pass
\ No newline at end of file
diff --git a/base_miner/gating_mechanisms/gating_mechanism.py b/base_miner/gating_mechanisms/gating_mechanism.py
deleted file mode 100644
index 58ff7050..00000000
--- a/base_miner/gating_mechanisms/gating_mechanism.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from PIL import Image
-from base_miner.registry import GATE_REGISTRY
-
-
-class GatingMechanism:
- """
- This class orchestrates multi-gate content detection and content-specific
- preprocessing to facilitate use by downstream models
-
- This is useful for routing images to appropriate detectors
- trained to handle different content types in a mixture-of-experts
- framework such as Content-Aware Model Orchestration (CAMO).
- """
- def __init__(self, gate_names: list):
- self.gates = {
- gate: GATE_REGISTRY[gate.upper()]()
- for gate in gate_names
- }
-
- def __call__(self, image: Image):
- gate_results = {}
- for gate_name, gate in self.gates.items():
- gate_output_image, gate_activated = gate(image)
- if gate_activated:
- gate_results[gate_name] = gate_output_image
-
- return gate_results
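The GatingMechanism removed above routes an image through every registered gate and keeps the output of each gate that activates. A minimal, self-contained sketch of the same routing idea, without the project's dlib-based gates; the gate below is purely illustrative.

from PIL import Image

def toy_face_gate(image):
    # Illustrative stand-in: a real gate would run a face detector here
    detected = image.size[0] >= 64 and image.size[1] >= 64
    return image, detected

def route(image, gates):
    """Return {gate_name: preprocessed image} for every gate that activates."""
    results = {}
    for name, gate in gates.items():
        output_image, activated = gate(image)
        if activated:
            results[name] = output_image
    return results

# Hypothetical usage with an in-memory image
# routed = route(Image.new("RGB", (128, 128)), {"face": toy_face_gate})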
diff --git a/base_miner/gating_mechanisms/utils/__init__.py b/base_miner/gating_mechanisms/utils/__init__.py
deleted file mode 100644
index 2436a5ae..00000000
--- a/base_miner/gating_mechanisms/utils/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .face_utils import get_face_landmarks, align_and_crop_face
\ No newline at end of file
diff --git a/base_miner/gating_mechanisms/utils/face_utils.py b/base_miner/gating_mechanisms/utils/face_utils.py
deleted file mode 100644
index e30621b8..00000000
--- a/base_miner/gating_mechanisms/utils/face_utils.py
+++ /dev/null
@@ -1,83 +0,0 @@
-from skimage import transform as trans
-import numpy as np
-import cv2
-
-
-def get_face_landmarks(face):
- """
- Args:
- face: dlib face rectangle object for face
-
- Returns:
- numpy array containing key points for eyes, nose and mouth
- """
- leye = np.array([face.part(37).x, face.part(37).y]).reshape(-1, 2)
- reye = np.array([face.part(44).x, face.part(44).y]).reshape(-1, 2)
- nose = np.array([face.part(30).x, face.part(30).y]).reshape(-1, 2)
- lmouth = np.array([face.part(49).x, face.part(49).y]).reshape(-1, 2)
- rmouth = np.array([face.part(55).x, face.part(55).y]).reshape(-1, 2)
- return np.concatenate([leye, reye, nose, lmouth, rmouth], axis=0)
-
-
-def align_and_crop_face(
- img: np.ndarray,
- landmarks: np.ndarray, outsize: tuple, scale=1.3, mask=None):
- """
- Align and crop the face according to the given landmarks
- Args:
- img: input image containing the face
- landmarks: 5 key points of face, determined by get_face_landmarks
- outsize: size to use in scaling
- scale: margin
- mask: optional face mask to transform alongside the face
-
- Returns:
- cropped and aligned face, optionally with a correspondingly
- cropped and aligned mask
- """
- target_size = [112, 112]
- dst = np.array([
- [30.2946, 51.6963],
- [65.5318, 51.5014],
- [48.0252, 71.7366],
- [33.5493, 92.3655],
- [62.7299, 92.2041]], dtype=np.float32)
-
- if target_size[1] == 112:
- dst[:, 0] += 8.0
-
- dst[:, 0] = dst[:, 0] * outsize[0] / target_size[0]
- dst[:, 1] = dst[:, 1] * outsize[1] / target_size[1]
-
- target_size = outsize
-
- margin_rate = scale - 1
- x_margin = target_size[0] * margin_rate / 2.
- y_margin = target_size[1] * margin_rate / 2.
-
- # move
- dst[:, 0] += x_margin
- dst[:, 1] += y_margin
-
- # resize
- dst[:, 0] *= target_size[0] / (target_size[0] + 2 * x_margin)
- dst[:, 1] *= target_size[1] / (target_size[1] + 2 * y_margin)
-
- src = landmarks.astype(np.float32)
-
- # use skimage transformation
- tform = trans.SimilarityTransform()
- tform.estimate(src, dst)
- M = tform.params[0:2, :]
-
- img = cv2.warpAffine(img, M, (target_size[1], target_size[0]))
-
- if outsize is not None:
- img = cv2.resize(img, (outsize[1], outsize[0]))
-
- if mask is not None:
- mask = cv2.warpAffine(mask, M, (target_size[1], target_size[0]))
- mask = cv2.resize(mask, (outsize[1], outsize[0]))
- return img, mask
- else:
- return img, None
diff --git a/base_miner/registry.py b/base_miner/registry.py
deleted file mode 100644
index 628b2b48..00000000
--- a/base_miner/registry.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import os
-import importlib
-import inspect
-
-class Registry(object):
- def __init__(self):
- self.data = {}
-
- def register_module(self, module_name=None):
- def _register(cls):
- name = module_name
- if module_name is None:
- name = cls.__name__
- self.data[name] = cls
- return cls
- return _register
-
- def __getitem__(self, key):
- return self.data[key]
-
- def __contains__(self, key):
- return key in self.data
-
-DETECTOR_REGISTRY = Registry()
-GATE_REGISTRY = Registry()
\ No newline at end of file
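The Registry removed above is a small decorator-based lookup table: classes register themselves under a name, and callers instantiate them by key. A self-contained sketch of how it is used; the dummy detector below is illustrative only.

class Registry:
    def __init__(self):
        self.data = {}

    def register_module(self, module_name=None):
        def _register(cls):
            self.data[module_name or cls.__name__] = cls
            return cls
        return _register

    def __getitem__(self, key):
        return self.data[key]

DETECTOR_REGISTRY = Registry()

@DETECTOR_REGISTRY.register_module(module_name="DUMMY")
class DummyDetector:
    def __call__(self, image):
        return 0.5  # placeholder score

detector = DETECTOR_REGISTRY["DUMMY"]()  # look up by name and instantiate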
diff --git a/bitmind/__init__.py b/bitmind/__init__.py
index 7cd2bfa6..00b4d740 100644
--- a/bitmind/__init__.py
+++ b/bitmind/__init__.py
@@ -1,27 +1,8 @@
-# The MIT License (MIT)
-# Copyright © 2023 Yuma Rao
-# developer: dubm
-# Copyright © 2023 Bitmind
+__version__ = "3.0.0"
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
-
-# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-
-
-__version__ = "2.2.11"
version_split = __version__.split(".")
__spec_version__ = (
- (1000 * int(version_split[0]))
- + (10 * int(version_split[1]))
- + (1 * int(version_split[2]))
+ (100000 * int(version_split[0]))
+ + (1000 * int(version_split[1]))
+ + (10 * int(version_split[2]))
)
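The __init__.py change above bumps the package version to 3.0.0 and widens the spec-version encoding from weights (1000, 10, 1) to (100000, 1000, 10), giving each version component more headroom. A quick check of both encodings:

def spec_version(version: str, major_w: int, minor_w: int, patch_w: int) -> int:
    major, minor, patch = (int(p) for p in version.split("."))
    return major_w * major + minor_w * minor + patch_w * patch

print(spec_version("2.2.11", 1000, 10, 1))      # old encoding -> 2031
print(spec_version("3.0.0", 100000, 1000, 10))  # new encoding -> 300000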
diff --git a/bitmind/autoupdater.py b/bitmind/autoupdater.py
new file mode 100644
index 00000000..a670c15c
--- /dev/null
+++ b/bitmind/autoupdater.py
@@ -0,0 +1,78 @@
+# The MIT License (MIT)
+# Copyright © 2023 Yuma Rao
+# Copyright © 2024 Manifold Labs
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
+# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
+# the Software.
+
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
+# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+
+import signal
+import time
+import os
+import requests
+import bittensor as bt
+import bitmind
+
+
+def autoupdate(branch: str = "v3", force=False):
+ """
+ Automatically updates the codebase to the latest version available on the specified branch.
+
+ This function checks the remote repository for the latest version by fetching the VERSION file from the specified branch.
+ If the local version is older than the remote version, it performs a git pull to update the local codebase to the latest version.
+ After successfully updating, it restarts the application with the updated code.
+
+ Args:
+    - branch (str): The name of the branch to check for updates. Defaults to "v3".
+
+ Note:
+ - The function assumes that the local codebase is a git repository and has the same structure as the remote repository.
+ - It requires git to be installed and accessible from the command line.
+ - The function will restart the application using the same command-line arguments it was originally started with.
+ - If the update fails, manual intervention is required to resolve the issue and restart the application.
+ """
+ bt.logging.info("Checking for updates...")
+ try:
+ github_url = f"https://raw.githubusercontent.com/BitMind-AI/bitmind-subnet/{branch}/VERSION?ts={time.time()}"
+ bt.logging.info(github_url)
+ response = requests.get(
+ github_url,
+ headers={"Cache-Control": "no-cache"},
+ )
+ response.raise_for_status()
+ repo_version = response.content.decode()
+ latest_version = int("".join(repo_version.split(".")))
+ local_version = int("".join(bitmind.__version__.split(".")))
+
+ bt.logging.info(f"Local version: {bitmind.__version__}")
+ bt.logging.info(f"Latest version: {repo_version}")
+
+ if latest_version > local_version or force:
+            bt.logging.info("A newer version is available. Updating...")
+ base_path = os.path.abspath(__file__)
+ while os.path.basename(base_path) != "bitmind-subnet":
+ base_path = os.path.dirname(base_path)
+
+ os.system(f"cd {base_path} && git pull && chmod +x setup.sh && ./setup.sh")
+
+ with open(os.path.join(base_path, "VERSION")) as f:
+ new_version = f.read().strip()
+ new_version = int("".join(new_version.split(".")))
+
+ if new_version == latest_version:
+ bt.logging.info("Updated successfully. Restarting...")
+ os.kill(os.getpid(), signal.SIGINT)
+ else:
+ bt.logging.error("Update failed. Manual update required.")
+ except Exception as e:
+ bt.logging.error(f"Update check failed: {e}")
diff --git a/bitmind/base/__init__.py b/bitmind/base/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/bitmind/base/bm_dendrite.py b/bitmind/base/bm_dendrite.py
deleted file mode 100644
index 04d407d8..00000000
--- a/bitmind/base/bm_dendrite.py
+++ /dev/null
@@ -1,372 +0,0 @@
-import asyncio
-import time
-import uuid
-from typing import Any, AsyncGenerator, Optional, Union, Type, List
-
-import aiohttp
-from bittensor_wallet import Keypair, Wallet
-
-from bittensor.core.axon import Axon
-from bittensor.core.chain_data import AxonInfo
-from bittensor.core.stream import StreamingSynapse
-from bittensor.core.synapse import Synapse
-from bittensor.utils.btlogging import logging
-from bittensor.core.dendrite import Dendrite
-
-class BMDendrite(Dendrite):
- """
- Enhanced Dendrite implementation with improved connection pooling and resilience.
-
- This class extends the standard Dendrite to provide better handling of concurrent
- connections, automatic retries for common network issues, and batch processing
- of multiple axon queries to prevent resource exhaustion.
-
- Args:
- wallet (Optional[Union["Wallet", "Keypair"]]): The wallet or keypair used for
- signing messages. Same as parent Dendrite.
- max_connections (int): Maximum number of total concurrent connections.
- max_connections_per_axon (int): Maximum number of concurrent connections per host.
- retry_attempts (int): Number of retry attempts for recoverable errors.
- batch_size (int): Number of axons to query in a single batch when running async.
- keepalive_timeout (float): How long to keep connections alive in the pool (seconds).
- """
-
- def __init__(
- self,
- wallet: Optional[Union["Wallet", "Keypair"]] = None,
- max_connections: int = 100,
- max_connections_per_axon: int = 8,
- retry_attempts: int = 2,
- batch_size: int = 20,
- keepalive_timeout: float = 15.0
- ):
- super().__init__(wallet=wallet)
-
- self.max_connections = max_connections
- self.max_connections_per_axon = max_connections_per_axon
- self.retry_attempts = retry_attempts
- self.batch_size = batch_size
- self.keepalive_timeout = keepalive_timeout
-
- self._session = None
-
- self._connection_metrics = {
- "total_requests": 0,
- "retried_requests": 0,
- "failed_requests": 0,
- "successful_requests": 0,
- }
-
- @property
- async def session(self) -> aiohttp.ClientSession:
- """
- An asynchronous property that provides access to the internal aiohttp client session
- with improved connection pooling.
-
- Returns:
- aiohttp.ClientSession: The active aiohttp client session instance with custom connection pooling.
- """
- if self._session is None:
- connector = aiohttp.TCPConnector(
- limit=self.max_connections,
- limit_per_host=self.max_connections_per_axon,
- force_close=False,
- enable_cleanup_closed=True,
- keepalive_timeout=self.keepalive_timeout
- )
-
- self._session = aiohttp.ClientSession(
- connector=connector,
- timeout=aiohttp.ClientTimeout(
- total=None,
- connect=5.0,
- sock_connect=5.0,
- sock_read=10.0
- ),
- raise_for_status=False # handle HTTP status errors within the class
- )
- return self._session
-
- async def forward(
- self,
- axons: Union[list[Union["AxonInfo", "Axon"]], Union["AxonInfo", "Axon"]],
- synapse: "Synapse" = Synapse(),
- timeout: float = 12,
- deserialize: bool = True,
- run_async: bool = True,
- streaming: bool = False,
- ) -> list[Union["AsyncGenerator[Any, Any]", "Synapse", "StreamingSynapse"]]:
- """
- Enhanced forward method with batch processing and improved error handling.
-
- This implementation processes axons in batches when running asynchronously to prevent
- overwhelming network resources and connection pools.
-
- Args:
- axons: Target axons to query (single axon or list of axons)
- synapse: The Synapse object to send
- timeout: Maximum time to wait for a response
- deserialize: Whether to deserialize the response
- run_async: Whether to run queries concurrently
- streaming: Whether the response is expected as a stream
-
- Returns:
- Response from axons (single response or list of responses)
- """
- is_list = True
- if not isinstance(axons, list):
- is_list = False
- axons = [axons]
-
- is_streaming_subclass = issubclass(synapse.__class__, StreamingSynapse)
- if streaming != is_streaming_subclass:
- logging.warning(
-                f"Argument streaming is {streaming} while issubclass({synapse.__class__.__name__}, StreamingSynapse) is {is_streaming_subclass}. This may cause unexpected behavior."
- )
- streaming = is_streaming_subclass or streaming
-
- async def query_all_axons(
- is_stream: bool,
- ) -> Union["AsyncGenerator[Any, Any]", "Synapse", "StreamingSynapse"]:
- """Query all axons with improved connection handling."""
-
- async def single_axon_response_with_retry(
- target_axon: Union["AxonInfo", "Axon"],
- retries: int = 0
- ) -> Union["AsyncGenerator[Any, Any]", "Synapse", "StreamingSynapse"]:
- """Process a single axon with retry logic for connection errors."""
- self._connection_metrics["total_requests"] += 1
- try:
- if is_stream:
- # If in streaming mode, return the async_generator
- result = self.call_stream(
- target_axon=target_axon,
- synapse=synapse.model_copy(), # type: ignore
- timeout=timeout,
- deserialize=deserialize,
- )
- self._connection_metrics["successful_requests"] += 1
- return result
- else:
- # If not in streaming mode, simply call the axon and get the response.
- result = await self.call(
- target_axon=target_axon,
- synapse=synapse.model_copy(), # type: ignore
- timeout=timeout,
- deserialize=deserialize,
- )
- self._connection_metrics["successful_requests"] += 1
- return result
- except (aiohttp.ClientOSError, ConnectionResetError, aiohttp.ServerDisconnectedError) as e:
- # Retry on common network/connection errors
- error_str = str(e)
- is_retryable = (
- "Broken pipe" in error_str or
- "Connection reset" in error_str or
- "Server disconnected" in error_str
- )
-
- if retries < self.retry_attempts and is_retryable:
- backoff_time = 0.1 * (2 ** retries)
- logging.debug(
- f"Connection error to {target_axon.ip}:{target_axon.port}, "
- f"retrying in {backoff_time:.2f}s ({retries+1}/{self.retry_attempts})"
- )
- self._connection_metrics["retried_requests"] += 1
- await asyncio.sleep(backoff_time)
- return await single_axon_response_with_retry(target_axon, retries + 1)
-
- self._connection_metrics["failed_requests"] += 1
- raise
-
- if not run_async:
- return [
- await single_axon_response_with_retry(target_axon) for target_axon in axons
- ]
-
- all_responses = []
- for i in range(0, len(axons), self.batch_size):
- batch = axons[i:i+self.batch_size]
- batch_responses = await asyncio.gather(
- *(single_axon_response_with_retry(target_axon) for target_axon in batch),
- return_exceptions=True # Don't let one failure block others
- )
-
- # Process any exceptions that were captured
- for j, response in enumerate(batch_responses):
- if isinstance(response, Exception):
- failed_synapse = synapse.model_copy()
- target_axon = batch[j]
- failed_synapse = self.preprocess_synapse_for_request(
- target_axon, failed_synapse, timeout
- )
- failed_synapse = self.process_error_message(
- failed_synapse,
- failed_synapse.__class__.__name__,
- response
- )
- batch_responses[j] = failed_synapse
-
- all_responses.extend(batch_responses)
-
- return all_responses
-
- responses = await query_all_axons(streaming)
- return responses[0] if len(responses) == 1 and not is_list else responses
-
- async def call(
- self,
- target_axon: Union["AxonInfo", "Axon"],
- synapse: "Synapse" = Synapse(),
- timeout: float = 12.0,
- deserialize: bool = True,
- ) -> "Synapse":
- """
- Enhanced call method with improved error handling for connection issues.
-
- Args:
- target_axon: The target axon to query
- synapse: The Synapse object to send
- timeout: Maximum time to wait for a response
- deserialize: Whether to deserialize the response
-
- Returns:
- The response Synapse object
- """
-
- start_time = time.time()
- target_axon = (
- target_axon.info() if isinstance(target_axon, Axon) else target_axon
- )
-
- request_name = synapse.__class__.__name__
- url = self._get_endpoint_url(target_axon, request_name=request_name)
-
- synapse = self.preprocess_synapse_for_request(target_axon, synapse, timeout)
-
- try:
- self._log_outgoing_request(synapse)
-
- try:
- async with (await self.session).post(
- url=url,
- headers=synapse.to_headers(),
- json=synapse.model_dump(),
- timeout=aiohttp.ClientTimeout(total=timeout),
- ) as response:
- json_response = await response.json()
- self.process_server_response(response, json_response, synapse)
- except aiohttp.ClientPayloadError as e:
- if "Response payload is not completed" in str(e):
- synapse.dendrite.status_code = "499"
- synapse.dendrite.status_message = f"Incomplete response payload: {str(e)}"
- else:
- raise
- except aiohttp.ClientOSError as e:
- if "Broken pipe" in str(e):
- synapse.dendrite.status_code = "503"
- synapse.dendrite.status_message = f"Connection broken: {str(e)}"
- else:
- raise
-
- synapse.dendrite.process_time = str(time.time() - start_time)
-
- except Exception as e:
- synapse = self.process_error_message(synapse, request_name, e)
-
- finally:
- self._log_incoming_response(synapse)
- self.synapse_history.append(Synapse.from_headers(synapse.to_headers()))
- return synapse.deserialize() if deserialize else synapse
-
- async def call_stream(
- self,
- target_axon: Union["AxonInfo", "Axon"],
- synapse: "StreamingSynapse" = Synapse(),
- timeout: float = 12.0,
- deserialize: bool = True,
- ) -> "AsyncGenerator[Any, Any]":
- """
- Enhanced call_stream method for streaming responses with improved error handling.
-
- Args:
- target_axon: The target axon to query
- synapse: The Synapse object to send
- timeout: Maximum time to wait for initial response
- deserialize: Whether to deserialize the response
-
- Yields:
- Response chunks from the streaming endpoint
- """
- start_time = time.time()
- target_axon = (
- target_axon.info() if isinstance(target_axon, Axon) else target_axon
- )
-
- request_name = synapse.__class__.__name__
- endpoint = (
- f"0.0.0.0:{str(target_axon.port)}"
- if target_axon.ip == str(self.external_ip)
- else f"{target_axon.ip}:{str(target_axon.port)}"
- )
- url = f"http://{endpoint}/{request_name}"
-
- synapse = self.preprocess_synapse_for_request(target_axon, synapse, timeout)
-
- try:
- self._log_outgoing_request(synapse)
- stream_timeout = aiohttp.ClientTimeout(
- total=None,
- connect=10.0,
- sock_connect=10.0,
- sock_read=timeout
- )
-
- async with (await self.session).post(
- url,
- headers=synapse.to_headers(),
- json=synapse.model_dump(),
- timeout=stream_timeout,
- ) as response:
- try:
- async for chunk in synapse.process_streaming_response(response):
- yield chunk
- except (aiohttp.ClientPayloadError, aiohttp.ClientOSError) as e:
- error_msg = str(e)
- if "Broken pipe" in error_msg or "incomplete" in error_msg.lower():
- logging.warning(f"Streaming interrupted: {error_msg}")
- # The stream was interrupted, but we might have received partial data, so continue
-
- json_response = synapse.extract_response_json(response)
- self.process_server_response(response, json_response, synapse)
-
- synapse.dendrite.process_time = str(time.time() - start_time)
-
- except Exception as e:
- synapse = self.process_error_message(synapse, request_name, e)
-
- finally:
- self._log_incoming_response(synapse)
- self.synapse_history.append(Synapse.from_headers(synapse.to_headers()))
- if deserialize:
- yield synapse.deserialize()
- else:
- yield synapse
-
- def get_connection_metrics(self) -> dict:
- """
- Get metrics about connection usage and errors.
-
- Returns:
- dict: A dictionary containing connection metrics
- """
- return self._connection_metrics.copy()
-
- def reset_connection_metrics(self) -> None:
- """Reset all connection metrics counters"""
- self._connection_metrics = {
- "total_requests": 0,
- "retried_requests": 0,
- "failed_requests": 0,
- "successful_requests": 0,
- }
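The retry logic in the removed BMDendrite waits 0.1 * 2**retries seconds between attempts and only retries errors it considers transient. A minimal async sketch of that backoff pattern, independent of bittensor; the flaky call below is illustrative only.

import asyncio

async def call_with_retry(call, retry_attempts: int = 2):
    """Retry a coroutine-producing callable with exponential backoff."""
    retries = 0
    while True:
        try:
            return await call()
        except ConnectionResetError:
            if retries >= retry_attempts:
                raise
            backoff = 0.1 * (2 ** retries)  # 0.1s, 0.2s, 0.4s, ...
            await asyncio.sleep(backoff)
            retries += 1

async def flaky_call(state={"n": 0}):
    # Illustrative stand-in for a network call that fails twice, then succeeds
    state["n"] += 1
    if state["n"] < 3:
        raise ConnectionResetError("simulated drop")
    return "ok"

print(asyncio.run(call_with_retry(flaky_call)))  # -> ok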
diff --git a/bitmind/base/miner.py b/bitmind/base/miner.py
deleted file mode 100644
index a2228c30..00000000
--- a/bitmind/base/miner.py
+++ /dev/null
@@ -1,285 +0,0 @@
-# The MIT License (MIT)
-# Copyright © 2023 Yuma Rao
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
-
-# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-
-import time
-import asyncio
-import threading
-import argparse
-import traceback
-import typing
-
-import bittensor as bt
-
-from bitmind.base.neuron import BaseNeuron
-from bitmind.utils.config import add_miner_args
-
-from typing import Union
-
-
-class BaseMinerNeuron(BaseNeuron):
- """
- Base class for Bittensor miners.
- """
-
- neuron_type: str = "MinerNeuron"
-
- @classmethod
- def add_args(cls, parser: argparse.ArgumentParser):
- super().add_args(parser)
- add_miner_args(cls, parser)
-
- def __init__(self, config=None):
- super().__init__(config=config)
-
- # Warn if allowing incoming requests from anyone.
- if not self.config.blacklist.force_validator_permit:
- bt.logging.warning(
- "You are allowing non-validators to send requests to your miner. This is a security risk."
- )
- if self.config.blacklist.allow_non_registered:
- bt.logging.warning(
- "You are allowing non-registered entities to send requests to your miner. This is a security risk."
- )
-
- # attach miner-specific functions in subclass __init__
- self.axon = bt.axon(wallet=self.wallet, config=self.config() if callable(self.config) else self.config)
-
- # Instantiate runners
- self.should_exit: bool = False
- self.is_running: bool = False
- self.thread: Union[threading.Thread, None] = None
- self.lock = asyncio.Lock()
-
- def run(self):
- """
- Initiates and manages the main loop for the miner on the Bittensor network. The main loop handles graceful shutdown on keyboard interrupts and logs unforeseen errors.
-
- This function performs the following primary tasks:
- 1. Check for registration on the Bittensor network.
- 2. Starts the miner's axon, making it active on the network.
- 3. Periodically resynchronizes with the chain; updating the metagraph with the latest network state and setting weights.
-
- The miner continues its operations until `should_exit` is set to True or an external interruption occurs.
- During each epoch of its operation, the miner waits for new blocks on the Bittensor network, updates its
- knowledge of the network (metagraph), and sets its weights. This process ensures the miner remains active
- and up-to-date with the network's latest state.
-
- Note:
- - The function leverages the global configurations set during the initialization of the miner.
- - The miner's axon serves as its interface to the Bittensor network, handling incoming and outgoing requests.
-
- Raises:
- KeyboardInterrupt: If the miner is stopped by a manual interruption.
- Exception: For unforeseen errors during the miner's operation, which are logged for diagnosis.
- """
-
- # Check that miner is registered on the network.
- self.sync()
-
- # Serve passes the axon information to the network + netuid we are hosting on.
-        # This will auto-update if the axon port or external ip has changed.
- bt.logging.info(
- f"Serving miner axon {self.axon} on network: {self.config.subtensor.chain_endpoint} with netuid: {self.config.netuid}"
- )
- self.axon.serve(netuid=self.config.netuid, subtensor=self.subtensor)
-
- # Start starts the miner's axon, making it active on the network.
- self.axon.start()
-
- bt.logging.info(f"Miner starting at block: {self.block}")
-
- # This loop maintains the miner's operations until intentionally stopped.
- try:
- while not self.should_exit:
- while (
- self.block - self.metagraph.last_update[self.uid]
- < self.config.neuron.epoch_length
- ):
- # Wait before checking again.
- time.sleep(1)
-
- # Check if we should exit.
- if self.should_exit:
- break
-
- # Sync metagraph and potentially set weights.
- self.sync()
- self.step += 1
- time.sleep(60)
-
- # If someone intentionally stops the miner, it'll safely terminate operations.
- except KeyboardInterrupt:
- self.axon.stop()
- bt.logging.success("Miner killed by keyboard interrupt.")
- exit()
-
- # In case of unforeseen errors, the miner will log the error and continue operations.
- except Exception as e:
- bt.logging.error(traceback.format_exc())
-
- def run_in_background_thread(self):
- """
- Starts the miner's operations in a separate background thread.
- This is useful for non-blocking operations.
- """
- if not self.is_running:
- bt.logging.debug("Starting miner in background thread.")
- self.should_exit = False
- self.thread = threading.Thread(target=self.run, daemon=True)
- self.thread.start()
- self.is_running = True
- bt.logging.debug("Started")
-
- def stop_run_thread(self):
- """
- Stops the miner's operations that are running in the background thread.
- """
- if self.is_running:
- bt.logging.debug("Stopping miner in background thread.")
- self.should_exit = True
- if self.thread is not None:
- self.thread.join(5)
- self.is_running = False
- bt.logging.debug("Stopped")
-
- def __enter__(self):
- """
- Starts the miner's operations in a background thread upon entering the context.
- This method facilitates the use of the miner in a 'with' statement.
- """
- self.run_in_background_thread()
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- """
- Stops the miner's background operations upon exiting the context.
- This method facilitates the use of the miner in a 'with' statement.
-
- Args:
- exc_type: The type of the exception that caused the context to be exited.
- None if the context was exited without an exception.
- exc_value: The instance of the exception that caused the context to be exited.
- None if the context was exited without an exception.
- traceback: A traceback object encoding the stack trace.
- None if the context was exited without an exception.
- """
- self.stop_run_thread()
-
- def resync_metagraph(self):
- """Resyncs the metagraph and updates the hotkeys and moving averages based on the new metagraph."""
- bt.logging.info("resync_metagraph()")
-
- # Sync the metagraph.
- self.metagraph.sync(subtensor=self.subtensor)
-
- async def blacklist(
- self, synapse: bt.Synapse
- ) -> typing.Tuple[bool, str]:
- """
- Determines whether an incoming request should be blacklisted and thus ignored. Your implementation should
- define the logic for blacklisting requests based on your needs and desired security parameters.
-
- Blacklist runs before the synapse data has been deserialized (i.e. before synapse.data is available).
-        The synapse is instead constructed via the headers of the request. It is important to blacklist
- requests before they are deserialized to avoid wasting resources on requests that will be ignored.
-
- Args:
- synapse (bt.Synapse): A synapse object constructed from the headers of the incoming request.
-
- Returns:
- Tuple[bool, str]: A tuple containing a boolean indicating whether the synapse's hotkey is blacklisted,
- and a string providing the reason for the decision.
-
- This function is a security measure to prevent resource wastage on undesired requests. It should be enhanced
- to include checks against the metagraph for entity registration, validator status, and sufficient stake
- before deserialization of synapse data to minimize processing overhead.
-
- Example blacklist logic:
- - Reject if the hotkey is not a registered entity within the metagraph.
- - Consider blacklisting entities that are not validators or have insufficient stake.
-
- In practice it would be wise to blacklist requests from entities that are not validators, or do not have
- enough stake. This can be checked via metagraph.S and metagraph.validator_permit. You can always attain
- the uid of the sender via a metagraph.hotkeys.index( synapse.dendrite.hotkey ) call.
-
- Otherwise, allow the request to be processed further.
- """
- if synapse.dendrite is None or synapse.dendrite.hotkey is None:
- bt.logging.warning("Received a request without a dendrite or hotkey.")
- return True, "Missing dendrite or hotkey"
-
- # TODO(developer): Define how miners should blacklist requests.
- uid = self.metagraph.hotkeys.index(synapse.dendrite.hotkey)
- if (
- not self.config.blacklist.allow_non_registered
- and synapse.dendrite.hotkey not in self.metagraph.hotkeys
- ):
- # Ignore requests from un-registered entities.
- bt.logging.trace(
- f"Blacklisting un-registered hotkey {synapse.dendrite.hotkey}"
- )
- return True, "Unrecognized hotkey"
-
- if self.config.blacklist.force_validator_permit:
- # If the config is set to force validator permit, then we should only allow requests from validators.
- if not self.metagraph.validator_permit[uid] or self.metagraph.S[uid] < 30000:
- bt.logging.warning(
- f"Blacklisting a request from non-validator hotkey {synapse.dendrite.hotkey}"
- )
- return True, "Non-validator hotkey"
-
- bt.logging.trace(
- f"Not Blacklisting recognized hotkey {synapse.dendrite.hotkey}"
- )
- return False, "Hotkey recognized!"
-
- async def priority(self, synapse: bt.Synapse) -> float:
- """
- The priority function determines the order in which requests are handled. More valuable or higher-priority
- requests are processed before others. You should design your own priority mechanism with care.
-
- This implementation assigns priority to incoming requests based on the calling entity's stake in the metagraph.
-
- Args:
- synapse (bt.Synapse): The synapse object that contains metadata about the incoming request.
-
- Returns:
- float: A priority score derived from the stake of the calling entity.
-
-        Miners may receive messages from multiple entities at once. This function determines which request should be
- processed first. Higher values indicate that the request should be processed first. Lower values indicate
- that the request should be processed later.
-
- Example priority logic:
- - A higher stake results in a higher priority value.
- """
- if synapse.dendrite is None or synapse.dendrite.hotkey is None:
- bt.logging.warning("Received a request without a dendrite or hotkey.")
- return 0.0
-
- # TODO(developer): Define how miners should prioritize requests.
- caller_uid = self.metagraph.hotkeys.index(
- synapse.dendrite.hotkey
- ) # Get the caller index.
-
-        priority = float(
-            self.metagraph.S[caller_uid]
-        ) # Return the stake as the priority.
-        bt.logging.trace(
-            f"Prioritizing {synapse.dendrite.hotkey} with value: ", priority
-        )
-        return priority
diff --git a/bitmind/base/neuron.py b/bitmind/base/neuron.py
deleted file mode 100644
index 247afa16..00000000
--- a/bitmind/base/neuron.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# The MIT License (MIT)
-# Copyright © 2023 Yuma Rao
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
-
-# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-
-import copy
-import typing
-
-import bittensor as bt
-
-from abc import ABC, abstractmethod
-
-# Sync calls set weights and also resyncs the metagraph.
-from bitmind.utils.config import check_config, add_args, config
-from bitmind.utils.misc import ttl_get_block
-from bitmind import __spec_version__ as spec_version
-from bitmind.utils.mock import MockSubtensor, MockMetagraph
-
-
-class BaseNeuron(ABC):
- """
- Base class for Bittensor neurons. This class is abstract and should be inherited by a subclass. It contains the core logic for all neurons: validators and miners.
-
- In addition to creating a wallet, subtensor, and metagraph, this class also handles the synchronization of the network state via a basic checkpointing mechanism based on epoch length.
- """
-
- neuron_type: str = "BaseNeuron"
-
- @classmethod
- def check_config(cls, config: "bt.Config"):
- check_config(cls, config)
-
- @classmethod
- def add_args(cls, parser):
- add_args(cls, parser)
-
- @classmethod
- def config(cls):
- return config(cls)
-
- subtensor: "bt.subtensor"
- wallet: "bt.wallet"
- metagraph: "bt.metagraph"
- spec_version: int = spec_version
-
- @property
- def block(self):
- return ttl_get_block(self)
-
- def __init__(self, config=None):
- base_config = copy.deepcopy(config or BaseNeuron.config())
- self.config = self.config()
- self.config.merge(base_config)
- self.check_config(self.config)
-
- # Set up logging with the provided configuration.
- bt.logging.set_config(config=self.config.logging)
-
- # If a gpu is required, set the device to cuda:N (e.g. cuda:0)
- self.device = self.config.neuron.device
-
- # Log the configuration for reference.
- bt.logging.info(self.config)
-
- # Build Bittensor objects
- # These are core Bittensor classes to interact with the network.
- bt.logging.info("Setting up bittensor objects.")
-
- # The wallet holds the cryptographic key pairs for the miner.
- if self.config.mock:
- self.wallet = bt.MockWallet(config=self.config)
- self.subtensor = MockSubtensor(
- self.config.netuid, wallet=self.wallet
- )
- self.metagraph = MockMetagraph(
- self.config.netuid, subtensor=self.subtensor
- )
- else:
- self.wallet = bt.wallet(config=self.config)
- self.subtensor = bt.subtensor(config=self.config)
- self.metagraph = self.subtensor.metagraph(self.config.netuid)
-
- bt.logging.info(f"Wallet: {self.wallet}")
- bt.logging.info(f"Subtensor: {self.subtensor}")
- bt.logging.info(f"Metagraph: {self.metagraph}")
-
- # Check if the miner is registered on the Bittensor network before proceeding further.
- self.check_registered()
-
- # Each miner gets a unique identity (UID) in the network for differentiation.
- self.uid = self.metagraph.hotkeys.index(
- self.wallet.hotkey.ss58_address
- )
- bt.logging.info(
- f"Running neuron on subnet: {self.config.netuid} with uid {self.uid} using network: {self.subtensor.chain_endpoint}"
- )
- self.step = 0
-
- @abstractmethod
- def run(self):
- ...
-
- def sync(self):
- """
- Wrapper for synchronizing the state of the network for the given miner or validator.
- """
- # Ensure miner or validator hotkey is still registered on the network.
- self.check_registered()
-
- if self.should_sync_metagraph():
- self.resync_metagraph()
-
- if self.should_set_weights():
- self.set_weights()
-
- # Always save state.
- self.save_state()
-
- def check_registered(self):
- # --- Check for registration.
- if not self.subtensor.is_hotkey_registered(
- netuid=self.config.netuid,
- hotkey_ss58=self.wallet.hotkey.ss58_address,
- ):
- bt.logging.error(
- f"Wallet: {self.wallet} is not registered on netuid {self.config.netuid}."
- f" Please register the hotkey using `btcli subnets register` before trying again"
- )
- exit()
-
- def should_sync_metagraph(self):
- """
- Check if enough epoch blocks have elapsed since the last checkpoint to sync.
- """
- return (
- self.block - self.metagraph.last_update[self.uid]
- ) > self.config.neuron.epoch_length
-
- def should_set_weights(self) -> bool:
- # Don't set weights on initialization.
- if self.step == 0:
- return False
-
- # Check if enough epoch blocks have elapsed since the last epoch.
- if self.config.neuron.disable_set_weights:
- return False
-
- # Define appropriate logic for when to set weights.
- return (
- (self.block - self.metagraph.last_update[self.uid])
- > self.config.neuron.epoch_length
- and self.neuron_type != "MinerNeuron"
- ) # don't set weights if you're a miner
-
- def save_state(self):
- bt.logging.warning(
- "save_state() not implemented for this neuron. You can implement this function to save model checkpoints or other useful data."
- )
-
- def load_state(self):
- bt.logging.warning(
- "load_state() not implemented for this neuron. You can implement this function to load model checkpoints or other useful data."
- )
diff --git a/bitmind/base/utils/weight_utils.py b/bitmind/base/utils/weight_utils.py
deleted file mode 100644
index c009d253..00000000
--- a/bitmind/base/utils/weight_utils.py
+++ /dev/null
@@ -1,209 +0,0 @@
-import numpy as np
-from typing import Tuple, List, Union, Any
-import bittensor
-from numpy import ndarray, dtype, floating, complexfloating
-
-U32_MAX = 4294967295
-U16_MAX = 65535
-
-
-def normalize_max_weight(
- x: np.ndarray, limit: float = 0.1
-) -> np.ndarray:
- r"""Normalizes the numpy array x so that sum(x) = 1 and the max value is not greater than the limit.
- Args:
- x (:obj:`np.ndarray`):
- Array to be max_value normalized.
- limit: float:
- Max value after normalization.
- Returns:
- y (:obj:`np.ndarray`):
- Normalized x array.
- """
- epsilon = 1e-7 # For numerical stability after normalization
-
- weights = x.copy()
- values = np.sort(weights)
-
- if x.sum() == 0 or len(x) * limit <= 1:
- return np.ones_like(x) / x.size
- else:
- estimation = values / values.sum()
-
- if estimation.max() <= limit:
- return weights / weights.sum()
-
- # Find the cumulative sum and sorted array
- cumsum = np.cumsum(estimation, 0)
-
- # Determine the index of cutoff
- estimation_sum = np.array(
- [(len(values) - i - 1) * estimation[i] for i in range(len(values))]
- )
- n_values = (estimation / (estimation_sum + cumsum + epsilon) < limit).sum()
-
- # Determine the cutoff based on the index
- cutoff_scale = (limit * cumsum[n_values - 1] - epsilon) / (
- 1 - (limit * (len(estimation) - n_values))
- )
- cutoff = cutoff_scale * values.sum()
-
- # Applying the cutoff
- weights[weights > cutoff] = cutoff
-
- y = weights / weights.sum()
-
- return y
-
-
-def convert_weights_and_uids_for_emit(
- uids: np.ndarray, weights: np.ndarray
-) -> Tuple[List[int], List[int]]:
- r"""Converts weights into integer u32 representation that sum to MAX_INT_WEIGHT.
- Args:
- uids (:obj:`np.ndarray,`):
- Array of uids as destinations for passed weights.
- weights (:obj:`np.ndarray,`):
- Array of weights.
- Returns:
- weight_uids (List[int]):
- Uids as a list.
- weight_vals (List[int]):
- Weights as a list.
- """
- # Checks.
- uids = np.asarray(uids)
- weights = np.asarray(weights)
-
- # Get non-zero weights and corresponding uids
- non_zero_weights = weights[weights > 0]
- non_zero_weight_uids = uids[weights > 0]
-
- # Debugging information
- bittensor.logging.debug(f"weights: {weights}")
- bittensor.logging.debug(f"non_zero_weights: {non_zero_weights}")
- bittensor.logging.debug(f"uids: {uids}")
- bittensor.logging.debug(f"non_zero_weight_uids: {non_zero_weight_uids}")
-
- if np.min(weights) < 0:
- raise ValueError(
- "Passed weight is negative cannot exist on chain {}".format(weights)
- )
- if np.min(uids) < 0:
- raise ValueError("Passed uid is negative cannot exist on chain {}".format(uids))
- if len(uids) != len(weights):
- raise ValueError(
- "Passed weights and uids must have the same length, got {} and {}".format(
- len(uids), len(weights)
- )
- )
- if np.sum(weights) == 0:
- bittensor.logging.debug("nothing to set on chain")
- return [], [] # Nothing to set on chain.
- else:
- max_weight = float(np.max(weights))
- weights = [
- float(value) / max_weight for value in weights
- ] # max-upscale values (max_weight = 1).
- bittensor.logging.debug(f"setting on chain max: {max_weight} and weights: {weights}")
-
- weight_vals = []
- weight_uids = []
- for i, (weight_i, uid_i) in enumerate(list(zip(weights, uids))):
- uint16_val = round(
- float(weight_i) * int(U16_MAX)
- ) # convert to int representation.
-
- # Filter zeros
- if uint16_val != 0: # Filter zeros
- weight_vals.append(uint16_val)
- weight_uids.append(uid_i)
- bittensor.logging.debug(f"final params: {weight_uids} : {weight_vals}")
- return weight_uids, weight_vals
-
-
-def process_weights_for_netuid(
- uids,
- weights: np.ndarray,
- netuid: int,
- subtensor: "bittensor.subtensor",
- metagraph: "bittensor.metagraph" = None,
- exclude_quantile: int = 0,
-) -> Union[tuple[ndarray[Any, dtype[Any]], Union[
- Union[ndarray[Any, dtype[floating[Any]]], ndarray[Any, dtype[complexfloating[Any, Any]]]], Any]], tuple[
- ndarray[Any, dtype[Any]], ndarray], tuple[Any, ndarray]]:
- bittensor.logging.debug("process_weights_for_netuid()")
- bittensor.logging.debug("weights", weights)
- bittensor.logging.debug("netuid", netuid)
- bittensor.logging.debug("subtensor", subtensor)
- bittensor.logging.debug("metagraph", metagraph)
-
- # Get latest metagraph from chain if metagraph is None.
- if metagraph is None:
- metagraph = subtensor.metagraph(netuid)
-
- # Cast weights to floats.
- if not isinstance(weights, np.ndarray) or weights.dtype != np.float32:
- weights = weights.astype(np.float32)
-
- # Network configuration parameters from the subtensor.
- # These parameters determine the range of acceptable weights for each neuron.
- quantile = exclude_quantile / U16_MAX
- min_allowed_weights = subtensor.min_allowed_weights(netuid=netuid)
- max_weight_limit = subtensor.max_weight_limit(netuid=netuid)
- bittensor.logging.debug("quantile", quantile)
- bittensor.logging.debug("min_allowed_weights", min_allowed_weights)
- bittensor.logging.debug("max_weight_limit", max_weight_limit)
-
- # Find all non zero weights.
- non_zero_weight_idx = np.argwhere(weights > 0).squeeze()
- if non_zero_weight_idx.ndim == 0:
- non_zero_weight_idx = non_zero_weight_idx.reshape((1,))
-
- non_zero_weight_uids = uids[non_zero_weight_idx]
- non_zero_weights = weights[non_zero_weight_idx]
- if non_zero_weights.size == 0 or metagraph.n < min_allowed_weights:
- bittensor.logging.warning("No non-zero weights returning all ones.")
- final_weights = np.ones(metagraph.n) / metagraph.n
- bittensor.logging.debug("final_weights", final_weights)
- return np.arange(len(final_weights)), final_weights
-
- elif non_zero_weights.size < min_allowed_weights:
- bittensor.logging.warning(
- "No non-zero weights less then min allowed weight, returning all ones."
- )
- weights = (
- np.ones(metagraph.n) * 1e-5
- ) # creating minimum even non-zero weights
- weights[non_zero_weight_idx] += non_zero_weights
- bittensor.logging.debug("final_weights", weights)
- normalized_weights = normalize_max_weight(
- x=weights, limit=max_weight_limit
- )
- return np.arange(len(normalized_weights)), normalized_weights
-
- bittensor.logging.debug("non_zero_weights", non_zero_weights)
-
- # Compute the exclude quantile and find the weights in the lowest quantile
- max_exclude = max(0, len(non_zero_weights) - min_allowed_weights) / len(
- non_zero_weights
- )
- exclude_quantile = min([quantile, max_exclude])
- lowest_quantile = np.quantile(non_zero_weights, exclude_quantile)
- bittensor.logging.debug("max_exclude", max_exclude)
- bittensor.logging.debug("exclude_quantile", exclude_quantile)
- bittensor.logging.debug("lowest_quantile", lowest_quantile)
-
- # Exclude all weights below the allowed quantile.
- non_zero_weight_uids = non_zero_weight_uids[lowest_quantile <= non_zero_weights]
- non_zero_weights = non_zero_weights[lowest_quantile <= non_zero_weights]
- bittensor.logging.debug("non_zero_weight_uids", non_zero_weight_uids)
- bittensor.logging.debug("non_zero_weights", non_zero_weights)
-
- # Normalize weights and return.
- normalized_weights = normalize_max_weight(
- x=non_zero_weights, limit=max_weight_limit
- )
- bittensor.logging.debug("final_weights", normalized_weights)
-
- return non_zero_weight_uids, normalized_weights
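
Note for reviewers: the removed `convert_weights_and_uids_for_emit` max-upscales weights so the largest becomes 1.0, then quantizes to u16 and drops zeros. A minimal standalone sketch of that arithmetic (illustrative values, not part of this diff):

```python
import numpy as np

U16_MAX = 65535

weights = np.array([0.2, 0.5, 0.3])
uids = np.array([1, 7, 42])

# Max-upscale so the largest weight becomes 1.0, then quantize to u16.
scaled = weights / weights.max()                  # [0.4, 1.0, 0.6]
uint_vals = [round(w * U16_MAX) for w in scaled]  # [26214, 65535, 39321]
print(list(zip(uids.tolist(), uint_vals)))        # [(1, 26214), (7, 65535), (42, 39321)]
```
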
diff --git a/bitmind/base/validator.py b/bitmind/base/validator.py
deleted file mode 100644
index 257f28dc..00000000
--- a/bitmind/base/validator.py
+++ /dev/null
@@ -1,479 +0,0 @@
-# The MIT License (MIT)
-# Copyright © 2023 Yuma Rao
-# TODO(developer): Set your name
-# Copyright © 2023
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
-
-# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-
-from traceback import print_exception
-from typing import List, Union
-from collections import deque
-import bittensor as bt
-import numpy as np
-import threading
-import argparse
-import asyncio
-import joblib
-import time
-import copy
-import os
-
-from bitmind.validator.miner_performance_tracker import MinerPerformanceTracker
-from bitmind.utils.config import add_validator_args
-from bitmind.utils.mock import MockDendrite
-from bitmind.base.neuron import BaseNeuron
-from bitmind.base.bm_dendrite import BMDendrite
-from bitmind.base.utils.weight_utils import (
- process_weights_for_netuid,
- convert_weights_and_uids_for_emit,
-)
-
-
-class BaseValidatorNeuron(BaseNeuron):
- """
- Base class for Bittensor validators. Your validator should inherit from this class.
- """
-
- neuron_type: str = "ValidatorNeuron"
-
- @classmethod
- def add_args(cls, parser: argparse.ArgumentParser):
- super().add_args(parser)
- add_validator_args(cls, parser)
-
- def __init__(self, config=None):
- super().__init__(config=config)
-
- self.performance_trackers = {
- 'image': None,
- 'video': None
- }
-
- self.image_history_cache_path = os.path.join(
- self.config.neuron.full_path, "image_miner_performance_tracker.pkl")
- self.video_history_cache_path = os.path.join(
- self.config.neuron.full_path, "video_miner_performance_tracker.pkl")
-
- # Save a copy of the hotkeys to local memory.
- self.hotkeys = copy.deepcopy(self.metagraph.hotkeys)
-
- # Dendrite lets us send messages to other nodes (axons) in the network.
- if self.config.mock:
- self.dendrite = MockDendrite(wallet=self.wallet)
- else:
- self.dendrite = BMDendrite(self.wallet, batch_size=25, max_connections_per_axon=2)
- bt.logging.info(f"Dendrite: {self.dendrite}")
-
- # Set up initial scoring weights for validation
- bt.logging.info("Building validation weights.")
- self.scores = np.zeros(self.metagraph.n, dtype=np.float32) # in case no saved scores are available
- self.load_state()
-
- # Init sync with the network. Updates the metagraph.
- self.sync()
-
- # Serve axon to enable external connections.
- if not self.config.neuron.axon_off:
- self.serve_axon()
- else:
- bt.logging.warning("axon off, not serving ip to chain.")
-
- # Create asyncio event loop to manage async tasks.
- self.loop = asyncio.get_event_loop()
-
- # Instantiate runners
- self.should_exit: bool = False
- self.is_running: bool = False
- self.thread: Union[threading.Thread, None] = None
- self.lock = asyncio.Lock()
-
- def serve_axon(self):
- """Serve axon to enable external connections."""
-
- bt.logging.info("serving ip to chain...")
- try:
- self.axon = bt.axon(wallet=self.wallet, config=self.config)
-
- try:
- self.subtensor.serve_axon(
- netuid=self.config.netuid,
- axon=self.axon,
- )
- bt.logging.info(
- f"Running validator {self.axon} on network: {self.config.subtensor.chain_endpoint} with netuid: {self.config.netuid}"
- )
- except Exception as e:
- bt.logging.error(f"Failed to serve Axon with exception: {e}")
- pass
-
- except Exception as e:
- bt.logging.error(f"Failed to create Axon initialize with exception: {e}")
- pass
-
- async def concurrent_forward(self):
- coroutines = [
- self.forward() for _ in range(self.config.neuron.num_concurrent_forwards)
- ]
- await asyncio.gather(*coroutines)
-
- def run(self):
- """
- Initiates and manages the main loop for the validator on the Bittensor network. The main loop handles graceful shutdown on keyboard interrupts and logs unforeseen errors.
-
- This function performs the following primary tasks:
- 1. Check for registration on the Bittensor network.
- 2. Continuously forwards queries to the miners on the network, rewarding their responses and updating the scores accordingly.
- 3. Periodically resynchronizes with the chain; updating the metagraph with the latest network state and setting weights.
-
- The essence of the validator's operations is in the forward function, which is called every step. The forward function is responsible for querying the network and scoring the responses.
-
- Note:
- - The function leverages the global configurations set during the initialization of the validator.
- - The validator's axon serves as its interface to the Bittensor network, handling incoming and outgoing requests.
-
- Raises:
- KeyboardInterrupt: If the validator is stopped by a manual interruption.
- Exception: For unforeseen errors during the validator's operation, which are logged for diagnosis.
- """
-
- # Check that validator is registered on the network.
- self.sync()
-
- bt.logging.info(f"Validator starting at block: {self.block}")
-
- # This loop maintains the validator's operations until intentionally stopped.
- try:
- while True:
- bt.logging.info(f"step({self.step}) block({self.block})")
-
- # Run multiple forwards concurrently.
- self.loop.run_until_complete(self.concurrent_forward())
-
- # Check if we should exit.
- if self.should_exit:
- break
-
- # Sync metagraph and potentially set weights.
- self.sync()
- time.sleep(60)
- self.step += 1
-
- # If someone intentionally stops the validator, it'll safely terminate operations.
- except KeyboardInterrupt:
- self.axon.stop()
- bt.logging.success("Validator killed by keyboard interrupt.")
- exit()
-
- # In case of unforeseen errors, the validator will log the error and continue operations.
- except Exception as err:
- bt.logging.error(f"Error during validation: {str(err)}")
- bt.logging.debug(str(print_exception(type(err), err, err.__traceback__)))
-
- def run_in_background_thread(self):
- """
- Starts the validator's operations in a background thread upon entering the context.
- This method facilitates the use of the validator in a 'with' statement.
- """
- if not self.is_running:
- bt.logging.debug("Starting validator in background thread.")
- self.should_exit = False
- self.thread = threading.Thread(target=self.run, daemon=True)
- self.thread.start()
- self.is_running = True
- bt.logging.debug("Started")
-
- def stop_run_thread(self):
- """
- Stops the validator's operations that are running in the background thread.
- """
- if self.is_running:
- bt.logging.debug("Stopping validator in background thread.")
- self.should_exit = True
- self.thread.join(5)
- self.is_running = False
- bt.logging.debug("Stopped")
-
- def __enter__(self):
- self.run_in_background_thread()
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- """
- Stops the validator's background operations upon exiting the context.
- This method facilitates the use of the validator in a 'with' statement.
-
- Args:
- exc_type: The type of the exception that caused the context to be exited.
- None if the context was exited without an exception.
- exc_value: The instance of the exception that caused the context to be exited.
- None if the context was exited without an exception.
- traceback: A traceback object encoding the stack trace.
- None if the context was exited without an exception.
- """
- if self.is_running:
- bt.logging.debug("Stopping validator in background thread.")
- self.should_exit = True
- self.thread.join(5)
- self.is_running = False
- bt.logging.debug("Stopped")
-
- def set_weights(self):
- """
- Sets the validator weights to the metagraph hotkeys based on the scores it has received from the miners.
- The weights determine the trust and incentive level the validator assigns to miner nodes on the network.
- """
-
- # Check if self.scores contains any NaN values and log a warning if it does.
- if np.isnan(self.scores).any():
- bt.logging.warning(
- f"Scores contain NaN values. This may be due to a lack of responses from miners, or a bug in your reward functions."
- )
-
- # Compute the L1 norm of the scores to normalize them into raw weights.
- norm = np.linalg.norm(self.scores, ord=1, axis=0, keepdims=True)
-
- # Check if the norm is zero or contains NaN values
- if np.any(norm == 0) or np.isnan(norm).any():
- norm = np.ones_like(norm) # Avoid division by zero or NaN
-
- # Compute raw_weights safely
- raw_weights = self.scores / norm
-
- bt.logging.debug("raw_weights", raw_weights)
- bt.logging.debug("raw_weight_uids", str(self.metagraph.uids.tolist()))
- # Process the raw weights to final_weights via subtensor limitations.
- (
- processed_weight_uids,
- processed_weights,
- ) = process_weights_for_netuid(
- uids=self.metagraph.uids,
- weights=raw_weights,
- netuid=self.config.netuid,
- subtensor=self.subtensor,
- metagraph=self.metagraph,
- )
- bt.logging.debug("processed_weights", processed_weights)
- bt.logging.debug("processed_weight_uids", processed_weight_uids)
-
- # Convert to uint16 weights and uids.
- (
- uint_uids,
- uint_weights,
- ) = convert_weights_and_uids_for_emit(
- uids=processed_weight_uids, weights=processed_weights
- )
- bt.logging.debug("uint_weights", uint_weights)
- bt.logging.debug("uint_uids", uint_uids)
-
- # Set the weights on chain via our subtensor connection.
- result, msg = self.subtensor.set_weights(
- wallet=self.wallet,
- netuid=self.config.netuid,
- uids=uint_uids,
- weights=uint_weights,
- wait_for_finalization=False,
- wait_for_inclusion=False,
- version_key=self.spec_version,
- )
- if result is True:
- bt.logging.info("set_weights on chain successfully!")
- else:
- bt.logging.error("set_weights failed", msg)
-
- def resync_metagraph(self):
- """Resyncs the metagraph and updates the hotkeys and moving averages based on the new metagraph."""
- bt.logging.info("resync_metagraph()")
-
- # Copies state of metagraph before syncing.
- previous_metagraph = copy.deepcopy(self.metagraph)
-
- # Sync the metagraph.
- self.metagraph.sync(subtensor=self.subtensor)
-
- for uid, hotkey in enumerate(self.hotkeys):
- if self.metagraph.validator_permit[uid] and self.metagraph.S[uid] > self.config.neuron.vpermit_tao_limit:
- self.scores[uid] = 0
-
- # Check if the metagraph axon info has changed.
- if previous_metagraph.axons == self.metagraph.axons:
- return
-
- bt.logging.info(
- "Metagraph updated, re-syncing hotkeys, dendrite pool and moving averages"
- )
- # Zero out all hotkeys that have been replaced.
- for uid, hotkey in enumerate(self.hotkeys):
- if hotkey != self.metagraph.hotkeys[uid]:
- self.scores[uid] = 0 # hotkey has been replaced
-
- # Check to see if the metagraph has changed size.
- # If so, we need to add new hotkeys and moving averages.
- if len(self.hotkeys) < len(self.metagraph.hotkeys):
- # Update the size of the moving average scores.
- new_moving_average = np.zeros((self.metagraph.n))
- min_len = min(len(self.hotkeys), len(self.scores))
- new_moving_average[:min_len] = self.scores[:min_len]
- self.scores = new_moving_average
-
- # Update the hotkeys.
- self.hotkeys = copy.deepcopy(self.metagraph.hotkeys)
-
- def update_scores(self, rewards: np.ndarray, uids: List[int]):
- """Performs exponential moving average on the scores based on the rewards received from the miners."""
-
- # Check if rewards contains NaN values.
- if np.isnan(rewards).any():
- bt.logging.warning(f"NaN values detected in rewards: {rewards}")
- # Replace any NaN values in rewards with 0.
- rewards = np.nan_to_num(rewards, nan=0)
-
- # Ensure rewards is a numpy array.
- rewards = np.asarray(rewards)
-
- # Check if `uids` is already a numpy array and copy it to avoid the warning.
- if isinstance(uids, np.ndarray):
- uids_array = uids.copy()
- else:
- uids_array = np.array(uids)
-
- # Handle edge case: If either rewards or uids_array is empty.
- if rewards.size == 0 or uids_array.size == 0:
- bt.logging.info(f"rewards: {rewards}, uids_array: {uids_array}")
- bt.logging.warning(
- "Either rewards or uids_array is empty. No updates will be performed."
- )
- return
-
- # Check if sizes of rewards and uids_array match.
- if rewards.size != uids_array.size:
- raise ValueError(
- f"Shape mismatch: rewards array of shape {rewards.shape} "
- f"cannot be broadcast to uids array of shape {uids_array.shape}"
- )
-
- # Compute forward pass rewards, assumes uids are mutually exclusive.
- # shape: [ metagraph.n ]
- scattered_rewards: np.ndarray = np.full_like(self.scores, 0.5)
- vali_uids = [
- uid for uid in range(len(scattered_rewards)) if
- self.metagraph.validator_permit[uid] and
- self.metagraph.S[uid] > self.config.neuron.vpermit_tao_limit
- ]
- no_response_uids = [
- uid for uid in range(len(scattered_rewards)) if all([
- self.performance_trackers[m].get_prediction_count(uid) == 0
- for m in ["image", "video"]
- ])
- ]
-
- scattered_rewards[vali_uids] = 0.
- scattered_rewards[no_response_uids] = 0.
- scattered_rewards[uids_array] = rewards
- bt.logging.debug(f"Scattered rewards: {rewards}")
-
- # Update scores with rewards produced by this step.
- # shape: [ metagraph.n ]
- alpha: float = self.config.neuron.moving_average_alpha
- self.scores: np.ndarray = alpha * scattered_rewards + (1 - alpha) * self.scores
- bt.logging.debug(f"Updated moving avg scores: {self.scores}")
-
- def save_miner_history(self):
- bt.logging.debug(f"Saving miner performance history to {self.image_history_cache_path}")
- joblib.dump(self.performance_trackers['image'], self.image_history_cache_path)
- bt.logging.debug(f"Saving miner performance history to {self.video_history_cache_path}")
- joblib.dump(self.performance_trackers['video'], self.video_history_cache_path)
-
- def load_miner_history(self):
- def convert_v1_to_v2(tracker):
- """Convert a v1 tracker to v2 format"""
- new_tracker = MinerPerformanceTracker(tracker.store_last_n_predictions)
-
- # copy hotkeys, transform predictions from float to vector
- new_tracker.miner_hotkeys = tracker.miner_hotkeys.copy()
- for uid in tracker.prediction_history:
- new_predictions = deque(maxlen=tracker.store_last_n_predictions)
- new_labels = deque(maxlen=tracker.store_last_n_predictions)
-
- for pred, label in zip(tracker.prediction_history[uid], tracker.label_history[uid]):
- new_labels.append(label)
- if isinstance(pred, float):
- if pred != -1:
- # convert old binary prediction to probability vector [p_real, p_synthetic, p_semi]
- new_predictions.append(np.array([1 - pred, pred, 0.0]))
- else:
- new_predictions.append(np.array([-1., -1., -1.]))
- elif isinstance(pred, np.ndarray):
- new_predictions.append(pred)
- else:
- raise ValueError(f"Invalid prediction type encountered while loading history: {pred}")
-
- new_tracker.prediction_history[uid] = new_predictions
- new_tracker.label_history[uid] = new_labels
- return new_tracker
-
- def load(path):
- if os.path.exists(path):
- bt.logging.info(f"Loading miner performance history from {path}")
- try:
- tracker = joblib.load(path)
- if not hasattr(tracker, 'version'):
- bt.logging.info(f"Converting performance tracker from v1 to v2 format")
- tracker = convert_v1_to_v2(tracker)
-
- num_miners_history = len([
- uid for uid in tracker.prediction_history
- if len([p for p in tracker.prediction_history[uid] if not np.array_equal(p, -1)]) > 0
- ])
- bt.logging.info(f"Loaded history for {num_miners_history} miners")
-
- except Exception as e:
- bt.logging.error(f'Error loading miner performance tracker: {e}')
- tracker = MinerPerformanceTracker()
- else:
- bt.logging.info(f"No miner performance history found at {path} - starting fresh!")
- tracker = MinerPerformanceTracker()
- return tracker
-
- self.performance_trackers['image'] = load(self.image_history_cache_path)
- self.performance_trackers['video'] = load(self.video_history_cache_path)
-
- def save_state(self):
- """Saves the state of the validator to a file."""
- bt.logging.info("Saving validator state.")
-
- # Save the state of the validator to file.
- np.savez(
- os.path.join(self.config.neuron.full_path, "state.npz"),
- step=self.step,
- scores=self.scores,
- hotkeys=self.hotkeys,
- )
- self.save_miner_history()
-
- def load_state(self):
- """Loads the state of the validator from a file."""
- bt.logging.info("Loading validator state.")
- state_path = os.path.join(self.config.neuron.full_path, "state.npz")
- # Load the state of the validator from file.
- if os.path.exists(state_path):
- state = np.load(state_path)
- self.step = state["step"]
- self.scores = state["scores"]
- self.hotkeys = state["hotkeys"]
- else:
- bt.logging.warning(f"Warning: no state file available at {state_path}")
- self.load_miner_history()
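
For context on the removed `update_scores`, miner scores are kept as an exponential moving average of per-step rewards. A tiny sketch of that update rule (alpha is a placeholder for `config.neuron.moving_average_alpha`):

```python
import numpy as np

alpha = 0.1  # placeholder for config.neuron.moving_average_alpha
scores = np.array([0.5, 0.5, 0.5], dtype=np.float32)
scattered_rewards = np.array([1.0, 0.0, 0.5], dtype=np.float32)

# new_score = alpha * reward + (1 - alpha) * old_score
scores = alpha * scattered_rewards + (1 - alpha) * scores
print(scores)  # approximately [0.55 0.45 0.5]
```
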
diff --git a/bitmind/cache/__init__.py b/bitmind/cache/__init__.py
new file mode 100644
index 00000000..b39d2e3f
--- /dev/null
+++ b/bitmind/cache/__init__.py
@@ -0,0 +1 @@
+from .cache_system import CacheSystem
diff --git a/bitmind/cache/cache_fs.py b/bitmind/cache/cache_fs.py
new file mode 100644
index 00000000..59a438b8
--- /dev/null
+++ b/bitmind/cache/cache_fs.py
@@ -0,0 +1,399 @@
+from pathlib import Path
+from typing import Dict, List, Optional, Union, Any
+import time
+import bittensor as bt
+
+from bitmind.types import CacheConfig, CacheType
+from bitmind.cache.util.filesystem import (
+ analyze_directory,
+ scale_size,
+ format_size,
+ print_directory_tree,
+ is_source_complete,
+)
+
+
+class CacheFS:
+ def __init__(self, config: CacheConfig):
+ self.config = config
+
+ self.cache_dir = config.get_path()
+ self.compressed_dir = self.cache_dir / "sources"
+ self.compressed_dir.mkdir(exist_ok=True, parents=True)
+
+ self._log_prefix = f"[{config.modality}:{config.media_type}]"
+
+ self._file_index = {}
+ self._index_timestamp = {}
+ self._index_ttl = 60
+
+ def set_index_ttl(self, seconds: int) -> None:
+ """Set the time-to-live for the file index in seconds."""
+ self._index_ttl = max(0, seconds)
+ self._log_debug(f"File index TTL set to {self._index_ttl} seconds")
+
+ def invalidate_index(
+ self, cache_type: Optional[Union[CacheType, str]] = None
+ ) -> None:
+ """Invalidate the file index for the specified cache type, or all indexes if None."""
+ if cache_type is None:
+ self._file_index = {}
+ self._index_timestamp = {}
+ self._log_info("All file indexes invalidated")
+ else:
+ if isinstance(cache_type, str):
+ cache_type = CacheType(cache_type.lower())
+
+ key = str(cache_type)
+ if key in self._file_index:
+ del self._file_index[key]
+ if key in self._index_timestamp:
+ del self._index_timestamp[key]
+
+ def _is_index_valid(self, cache_type: Union[CacheType, str]) -> bool:
+ """Check if the index for the given cache type is still valid based on TTL."""
+ if isinstance(cache_type, str):
+ cache_type = CacheType(cache_type.lower())
+
+ key = str(cache_type)
+ if self._index_ttl <= 0 or key not in self._index_timestamp:
+ return False
+
+ return time.time() - self._index_timestamp[key] < self._index_ttl
+
+ def num_files(
+ self,
+ cache_type: Union[CacheType, str] = CacheType.MEDIA,
+ file_extensions: Optional[List[str]] = None,
+ use_index: bool = True,
+ ) -> int:
+ """Returns the number of files of the given type and extensions."""
+ files = self.get_files(
+ cache_type=cache_type, file_extensions=file_extensions, use_index=use_index
+ )
+ return len(files)
+
+ def get_files(
+ self,
+ cache_type: Union[CacheType, str] = CacheType.MEDIA,
+ file_extensions: Optional[List[str]] = None,
+ group_by_source: bool = False,
+ use_index: bool = True,
+ ) -> Union[List[Path], Dict[str, List[Path]]]:
+ """
+ Get files of the specified type with the given extensions.
+
+ Args:
+ cache_type: Type of cache to search (Media or Compressed)
+ file_extensions: List of file extensions to filter by (e.g., ['.jpg', '.png'])
+ group_by_source: Whether to group files by their source directory
+ use_index: Whether to use indexed file list if available
+
+ Returns:
+ Either a list of file paths or a dictionary mapping source directories to lists of files
+ """
+ if isinstance(cache_type, str):
+ cache_type = CacheType(cache_type.lower())
+
+ if file_extensions is not None:
+ file_extensions = set([ext.lower() for ext in file_extensions])
+
+ key = str(cache_type)
+ if use_index and self._is_index_valid(cache_type) and key in self._file_index:
+ files = self._file_index[key]
+ if group_by_source:
+ return self._group_files_by_source(files, cache_type)
+ return files
+
+ if cache_type == CacheType.MEDIA:
+ base_dir = self.cache_dir
+ elif cache_type == CacheType.COMPRESSED:
+ base_dir = self.compressed_dir
+
+ files = []
+ if base_dir.exists():
+ dataset_dirs = [d for d in base_dir.iterdir() if d.is_dir()]
+ for dataset_dir in dataset_dirs:
+ for file in dataset_dir.iterdir():
+ extension_match = (
+ file_extensions is None
+ or file.suffix.lower() in file_extensions
+ )
+ if file.is_file() and extension_match:
+ files.append(file)
+
+ self._file_index[key] = files
+ self._index_timestamp[key] = time.time()
+
+ if group_by_source:
+ return self._group_files_by_source(files, cache_type)
+
+ return files
+
+ def is_empty(self, cache_type: Union[CacheType, str]) -> bool:
+ """
+ Efficiently check if a cache directory is empty.
+
+ Args:
+ cache_type: Type of cache to check (Media or Compressed)
+
+ Returns:
+ bool: True if the cache is empty, False if it contains any files
+ """
+ if isinstance(cache_type, str):
+ cache_type = CacheType(cache_type.lower())
+
+ base_dir = (
+ self.cache_dir if cache_type == CacheType.MEDIA else self.compressed_dir
+ )
+
+ if not base_dir.exists():
+ return True
+
+ try:
+ # Check if there are any dataset directories
+ dataset_dir = next(
+ d for d in base_dir.iterdir() if d.is_dir() and d != self.compressed_dir
+ )
+
+ # Check if any dataset directory contains files
+ try:
+ next(
+ f
+ for f in dataset_dir.iterdir()
+ if f.is_file()
+ and (cache_type == CacheType.MEDIA or is_source_complete(f))
+ )
+ return False
+ except StopIteration:
+ return True
+ except StopIteration:
+ return True
+
+ def _group_files_by_source(
+ self, files: List[Path], cache_type: CacheType
+ ) -> Dict[str, List[Path]]:
+ """Helper method to group files by their source directory.
+ TODO make this cache_type agnostic
+ """
+ if cache_type == CacheType.MEDIA:
+ base_dir = self.cache_dir
+ else:
+ base_dir = self.compressed_dir
+
+ result = {}
+ for file in files:
+ if file.exists():
+ try:
+ rel_path = file.relative_to(base_dir)
+ subdir = str(rel_path.parent)
+ except ValueError:
+ subdir = str(file.parent)
+
+ if subdir not in result:
+ result[subdir] = []
+ result[subdir].append(file)
+ return result
+
+ async def maybe_prune_cache(
+ self,
+ cache_type: Union[CacheType, str],
+ file_extensions: Optional[List[str]],
+ ) -> None:
+ """
+ Prune the cache if it exceeds the configured size limit.
+
+ Args:
+ cache_type: Type of cache to prune (Media or Compressed)
+ file_extensions: List of file extensions to consider for pruning
+ """
+ if isinstance(cache_type, str):
+ cache_type = CacheType(cache_type.lower())
+
+ if cache_type == CacheType.COMPRESSED:
+ max_gb = self.config.max_compressed_gb
+ elif cache_type == CacheType.MEDIA:
+ max_gb = self.config.max_media_gb
+
+ max_bytes = scale_size(max_gb, "GB", "B")
+ current_bytes = self.cache_size(cache_type, file_extensions, unit="B")
+ num_files = self.num_files(cache_type, file_extensions)
+ self._log_info(
+ f"Pruning Check | {cache_type} cache | {num_files} files | {format_size(current_bytes, 'B', 'GB')}"
+ )
+ if current_bytes <= max_bytes:
+ return
+
+ files = self.get_files(
+ cache_type=cache_type, file_extensions=file_extensions, use_index=True
+ )
+
+ files_dict = self._group_files_by_source(files, cache_type)
+
+ for subdir in files_dict:
+ files_dict[subdir] = sorted(
+ files_dict[subdir],
+ key=lambda f: f.stat().st_mtime if f.exists() else float("inf"),
+ )
+
+ self._log_info(f"Pruning cache to stay under {max_gb} GB...")
+
+ n_removed = 0
+ bytes_removed = 0
+ remaining_bytes = current_bytes
+
+ key = str(cache_type)
+ has_index = key in self._file_index
+
+ while remaining_bytes > max_bytes and any(
+ files for files in files_dict.values()
+ ):
+ largest_subdir = max(
+ [subdir for subdir, files in files_dict.items() if files],
+ key=lambda subdir: len(files_dict[subdir]),
+ default=None,
+ )
+
+ if largest_subdir is None:
+ break
+
+ file = files_dict[largest_subdir].pop(0)
+ try:
+ if file.exists():
+ file_size = file.stat().st_size
+ file.unlink()
+ if has_index:
+ try:
+ self._file_index[key].remove(file)
+ except ValueError:
+ pass
+
+ meta_file = file.with_suffix(".json")
+ if meta_file.exists():
+ meta_file.unlink()
+ if has_index:
+ try:
+ self._file_index[key].remove(meta_file)
+ except ValueError:
+ pass
+
+ n_removed += 1
+ bytes_removed += file_size
+ remaining_bytes -= file_size
+ except Exception as e:
+ self._log_error(f"Error removing file {file}: {e}")
+
+ removed_gb_str = format_size(bytes_removed, "B", "GB")
+ new_gb_str = self.cache_size(cache_type, file_extensions, "GB", as_str=True)
+ self._log_info(
+ f"Removed: {n_removed} files; {removed_gb_str} | New size: {new_gb_str}"
+ )
+
+ def cache_size(
+ self,
+ cache_type: CacheType,
+ file_extensions: Optional[List[str]] = None,
+ unit: str = "GB",
+ as_str: bool = False,
+ use_index: bool = True,
+ ) -> Union[str, float]:
+ """
+ Returns size of media or compressed cache.
+
+ Args:
+ cache_type: Type of cache to measure (Media or Compressed)
+ file_extensions: List of file extensions to filter by
+ unit: Unit to return the size in (e.g., 'B', 'KB', 'MB', 'GB')
+ as_str: Whether to return the size as a formatted string
+ use_index: Whether to use indexed file list if available
+
+ Returns:
+ Size of the cache, either as a float or a formatted string with units
+ """
+ files = self.get_files(
+ cache_type=cache_type, file_extensions=file_extensions, use_index=use_index
+ )
+ total_bytes = sum(f.stat().st_size for f in files if f.exists())
+ if as_str:
+ return format_size(total_bytes, "B", unit)
+ return scale_size(total_bytes, "B", unit)
+
+ def get_cache_stats(self, use_index: bool = True) -> Dict[str, Any]:
+ """
+ Get statistics about the cache.
+
+ Args:
+ use_index: Whether to use indexed file list if available
+
+ Returns:
+ Dictionary with cache statistics including file counts and sizes
+ """
+ media_files = self.get_files(CacheType.MEDIA, use_index=use_index)
+ media_count = len(media_files)
+ media_bytes = sum(f.stat().st_size for f in media_files if f.exists())
+ media_gb = scale_size(media_bytes, "B", "GB")
+
+ compressed_files = self.get_files(CacheType.COMPRESSED, use_index=use_index)
+ compressed_count = len(compressed_files)
+ compressed_bytes = sum(f.stat().st_size for f in compressed_files if f.exists())
+ compressed_gb = scale_size(compressed_bytes, "B", "GB")
+
+ return {
+ "cache_dir": str(self.cache_dir),
+ "modality": self.config.modality,
+ "media_type": self.config.media_type,
+ "media_count": media_count,
+ "media_bytes": media_bytes,
+ "media_gb": media_gb,
+ "compressed_count": compressed_count,
+ "compressed_bytes": compressed_bytes,
+ "compressed_gb": compressed_gb,
+ "total_count": media_count + compressed_count,
+ "total_bytes": media_bytes + compressed_bytes,
+ "total_gb": media_gb + compressed_gb,
+ }
+
+ def print_directory_tree(
+ self, min_file_count: int = 1, include_sources: bool = True
+ ):
+ """Print a tree representation of the cache directory structure."""
+ exclude_dirs = [] if include_sources else ["sources"]
+
+ self._log_info(f"Analyzing cache directory structure: {self.cache_dir}")
+ tree_data = analyze_directory(
+ self.cache_dir,
+ exclude_dirs=exclude_dirs,
+ min_file_count=min_file_count,
+ log_func=self._log_info,
+ )
+
+ self._log_info(f"\n{self.cache_dir}")
+ self._log_info(f"Modality: {self.config.modality}")
+ self._log_info(f"Media Type: {self.config.media_type}")
+ self._log_info(f"Total Size: {format_size(tree_data['size'])}")
+ self._log_info(f"Total Files: {tree_data['count']}")
+ if not include_sources:
+ self._log_info(
+ "Note: Source directories are excluded from the visualization"
+ )
+ self._log_info("-" * 80)
+
+ print_directory_tree(tree_data, "", True, "", self._log_info)
+
+ def _log_info(self, message: str) -> None:
+ bt.logging.info(f"{self._log_prefix} {message}")
+
+ def _log_warning(self, message: str) -> None:
+ bt.logging.warning(f"{self._log_prefix} {message}")
+
+ def _log_error(self, message: str) -> None:
+ bt.logging.error(f"{self._log_prefix} {message}")
+
+ def _log_debug(self, message: str) -> None:
+ bt.logging.debug(f"{self._log_prefix} {message}")
+
+ def _log_success(self, message: str) -> None:
+ bt.logging.success(f"{self._log_prefix} {message}")
+
+ def _log_trace(self, message: str) -> None:
+ bt.logging.trace(f"{self._log_prefix} {message}")
diff --git a/bitmind/cache/cache_system.py b/bitmind/cache/cache_system.py
new file mode 100644
index 00000000..2c4efa97
--- /dev/null
+++ b/bitmind/cache/cache_system.py
@@ -0,0 +1,244 @@
+from typing import Any, Dict, List, Optional, Type
+import traceback
+
+import asyncio
+import bittensor as bt
+
+from bitmind.types import CacheUpdaterConfig, CacheConfig, Modality, MediaType
+from bitmind.cache.datasets import DatasetRegistry, initialize_dataset_registry
+from bitmind.cache.updater import (
+ BaseUpdater,
+ UpdaterRegistry,
+ ImageUpdater,
+ VideoUpdater,
+)
+from bitmind.cache.sampler import (
+ BaseSampler,
+ SamplerRegistry,
+ ImageSampler,
+ VideoSampler,
+)
+
+
+class CacheSystem:
+ """
+ Main facade for the caching system.
+ """
+
+ def __init__(self):
+ self.dataset_registry = DatasetRegistry()
+ self.updater_registry = UpdaterRegistry()
+ self.sampler_registry = SamplerRegistry()
+
+ async def initialize(
+ self,
+ base_dir,
+ max_compressed_gb,
+ max_media_gb,
+ media_files_per_source,
+ ):
+ try:
+ dataset_registry = initialize_dataset_registry()
+ for dataset in dataset_registry.datasets:
+ self.register_dataset(dataset)
+
+ for modality in Modality:
+ for media_type in MediaType:
+ cache_config = CacheConfig(
+ base_dir=base_dir,
+ modality=modality.value,
+ media_type=media_type.value,
+ max_compressed_gb=max_compressed_gb,
+ max_media_gb=max_media_gb,
+ )
+ sampler_class = (
+ ImageSampler if modality == Modality.IMAGE else VideoSampler
+ )
+ self.create_sampler(
+ name=f"{media_type.value}_{modality.value}_sampler",
+ sampler_class=sampler_class,
+ cache_config=cache_config,
+ )
+
+ # synthetic video updater not currently used; synthetic videos are only generated locally
+ if not (
+ modality == Modality.VIDEO and media_type == MediaType.SYNTHETIC
+ ):
+ updater_config = CacheUpdaterConfig(
+ num_sources_per_dataset=1, # one compressed source per dataset for initialization
+ num_items_per_source=media_files_per_source,
+ )
+ updater_class = (
+ ImageUpdater if modality == Modality.IMAGE else VideoUpdater
+ )
+ self.create_updater(
+ name=f"{media_type.value}_{modality.value}_updater",
+ updater_class=updater_class,
+ cache_config=cache_config,
+ updater_config=updater_config,
+ )
+
+ # Initialize caches (populate if empty)
+ bt.logging.info("Starting initial cache population")
+ await self.initialize_caches()
+ bt.logging.info("Initial cache population complete")
+
+ except Exception as e:
+ bt.logging.error(f"Error initializing caches: {e}")
+ bt.logging.error(traceback.format_exc())
+
+ def register_dataset(self, dataset) -> None:
+ """
+ Register a dataset with the system.
+
+ Args:
+ dataset: Dataset configuration to register
+ """
+ self.dataset_registry.register(dataset)
+
+ def register_datasets(self, datasets: List[Any]) -> None:
+ """
+ Register multiple datasets with the system.
+
+ Args:
+ datasets: List of dataset configurations to register
+ """
+ self.dataset_registry.register_all(datasets)
+
+ def create_updater(
+ self,
+ name: str,
+ updater_class: Type[BaseUpdater],
+ cache_config: CacheConfig,
+ updater_config: CacheUpdaterConfig,
+ ) -> BaseUpdater:
+ """
+ Create and register an updater.
+
+ Args:
+ name: Unique name for the updater
+ updater_class: Updater class to instantiate
+ cache_config: Cache configuration
+ updater_config: Updater configuration
+
+ Returns:
+ The created updater instance
+ """
+ updater = updater_class(
+ cache_config=cache_config,
+ updater_config=updater_config,
+ data_manager=self.dataset_registry,
+ )
+ self.updater_registry.register(name, updater)
+ return updater
+
+ def create_sampler(
+ self, name: str, sampler_class: Type[BaseSampler], cache_config: CacheConfig
+ ) -> BaseSampler:
+ """
+ Create and register a sampler.
+
+ Args:
+ name: Unique name for the sampler
+ sampler_class: Sampler class to instantiate
+ cache_config: Cache configuration
+
+ Returns:
+ The created sampler instance
+ """
+ sampler = sampler_class(cache_config=cache_config)
+ self.sampler_registry.register(name, sampler)
+ return sampler
+
+ async def initialize_caches(self) -> None:
+ """
+ Initialize all caches to ensure they have content.
+ This is typically called during system startup.
+ """
+ updaters = self.updater_registry.get_all()
+ names = [name for name, _ in updaters.items()]
+ bt.logging.debug(f"Initializing {len(updaters)} caches: {names}")
+
+ cache_init_tasks = []
+ for name, updater in updaters.items():
+ cache_init_tasks.append(updater.initialize_cache())
+
+ if cache_init_tasks:
+ await asyncio.gather(*cache_init_tasks)
+
+ async def update_compressed_caches(self) -> None:
+ """
+ Update all compressed caches in parallel
+ This is typically called from a block callback.
+ """
+ updaters = self.updater_registry.get_all()
+ names = [name for name, _ in updaters.items()]
+ bt.logging.trace(f"Updating {len(updaters)} compressed caches: {names}")
+
+ tasks = []
+ for name, updater in updaters.items():
+ tasks.append(updater.update_compressed_cache())
+
+ if tasks:
+ await asyncio.gather(*tasks)
+
+ async def update_media_caches(self) -> None:
+ """
+ Update all media caches in parallel.
+ This is typically called from a block callback.
+ """
+ updaters = self.updater_registry.get_all()
+ names = [name for name, _ in updaters.items()]
+ bt.logging.debug(f"Updating {len(updaters)} media caches: {names}")
+
+ tasks = []
+ for name, updater in updaters.items():
+ tasks.append(updater.update_media_cache())
+
+ if tasks:
+ await asyncio.gather(*tasks)
+
+ async def sample(self, name: str, count: int, **kwargs) -> Optional[Dict[str, Any]]:
+ """
+ Sample from a specific sampler.
+
+ Args:
+ name: Name of the sampler to use
+ count: Number of items to sample
+
+ Returns:
+ The sampled items or None if sampler not found
+ """
+ return await self.sampler_registry.sample(name, count, **kwargs)
+
+ async def sample_all(self, count: int = 1) -> Dict[str, Dict[str, Any]]:
+ """
+ Sample from all samplers.
+
+ Args:
+ count: Number of items to sample from each sampler
+
+ Returns:
+ Dictionary mapping sampler names to their samples
+ """
+ return await self.sampler_registry.sample_all(count)
+
+ @property
+ def samplers(self):
+ """
+ Get all registered samplers.
+
+ Returns:
+ Dictionary of sampler names to sampler instances
+ """
+ return self.sampler_registry.get_all()
+
+ @property
+ def updaters(self):
+ """
+ Get all registered updaters.
+
+ Returns:
+ Dictionary of updater names to updater instances
+ """
+ return self.updater_registry.get_all()
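
A hedged sketch of how a validator process might drive the `CacheSystem` facade above; the argument values are placeholders, and only methods defined in this file are used:

```python
import asyncio
from bitmind.cache import CacheSystem

async def main():
    cs = CacheSystem()
    await cs.initialize(
        base_dir="~/.cache/sn34",  # placeholder path
        max_compressed_gb=100,
        max_media_gb=10,
        media_files_per_source=50,
    )

    # Periodic maintenance, normally scheduled from block callbacks.
    await cs.update_compressed_caches()
    await cs.update_media_caches()

    # Pull one item from every registered sampler.
    samples = await cs.sample_all(count=1)
    print(list(samples.keys()))

asyncio.run(main())
```
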
diff --git a/bitmind/cache/datasets/__init__.py b/bitmind/cache/datasets/__init__.py
new file mode 100644
index 00000000..54bdc80e
--- /dev/null
+++ b/bitmind/cache/datasets/__init__.py
@@ -0,0 +1,2 @@
+from .datasets import initialize_dataset_registry
+from .dataset_registry import DatasetRegistry
diff --git a/bitmind/cache/datasets/dataset_registry.py b/bitmind/cache/datasets/dataset_registry.py
new file mode 100644
index 00000000..ddfa2420
--- /dev/null
+++ b/bitmind/cache/datasets/dataset_registry.py
@@ -0,0 +1,108 @@
+from typing import List, Optional
+
+from bitmind.types import DatasetConfig, MediaType, Modality
+
+
+class DatasetRegistry:
+ """
+ Registry for dataset configurations with filtering capabilities.
+ """
+
+ def __init__(self):
+ self.datasets: List[DatasetConfig] = []
+
+ def register(self, dataset: DatasetConfig) -> None:
+ """
+ Register a dataset with the system.
+
+ Args:
+ dataset: Dataset configuration to register
+ """
+ self.datasets.append(dataset)
+
+ def register_all(self, datasets: List[DatasetConfig]) -> None:
+ """
+ Register multiple datasets with the system.
+
+ Args:
+ datasets: List of dataset configurations to register
+ """
+ for dataset in datasets:
+ self.register(dataset)
+
+ def get_datasets(
+ self,
+ modality: Optional[Modality] = None,
+ media_type: Optional[MediaType] = None,
+ tags: Optional[List[str]] = None,
+ exclude_tags: Optional[List[str]] = None,
+ enabled_only: bool = True,
+ ) -> List[DatasetConfig]:
+ """
+ Get datasets filtered by modality, media_type, and/or tags.
+
+ Args:
+ modality: Filter by modality (image or video)
+ media_type: Filter by media_type (real, synthetic, or semisynthetic)
+ tags: Filter by tags (dataset must have ALL specified tags)
+ exclude_tags: Exclude datasets that have ANY of these tags
+ enabled_only: Only return enabled datasets
+
+ Returns:
+ List of matching datasets
+ """
+ result = self.datasets
+
+ if enabled_only:
+ result = [d for d in result if d.enabled]
+
+ if modality:
+ if isinstance(modality, str):
+ modality = Modality(modality.lower())
+ result = [d for d in result if d.type == modality]
+
+ if media_type:
+ if isinstance(media_type, str):
+ media_type = MediaType(media_type.lower())
+ result = [d for d in result if d.media_type == media_type]
+
+ if tags:
+ result = [d for d in result if all(tag in d.tags for tag in tags)]
+
+ if exclude_tags:
+ result = [
+ d for d in result if all(tag not in d.tags for tag in exclude_tags)
+ ]
+
+ return result
+
+ def enable_dataset(self, path: str, enabled: bool = True) -> bool:
+ """
+ Enable or disable a dataset by path.
+
+ Args:
+ path: Dataset path to enable/disable
+ enabled: Whether to enable or disable
+
+ Returns:
+ True if successful, False if dataset not found
+ """
+ for dataset in self.datasets:
+ if dataset.path == path:
+ dataset.enabled = enabled
+ return True
+ return False
+
+ def get_dataset_by_path(self, path: str) -> Optional[DatasetConfig]:
+ """
+ Get a dataset by its path.
+
+ Args:
+ path: Dataset path to find
+
+ Returns:
+ Dataset config or None if not found
+ """
+ for dataset in self.datasets:
+ if dataset.path == path:
+ return dataset
+ return None
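
A short sketch of how `DatasetRegistry.get_datasets` filtering composes, using a dataset definition of the same shape as those in `datasets.py` below:

```python
from bitmind.types import DatasetConfig, MediaType, Modality
from bitmind.cache.datasets import DatasetRegistry

registry = DatasetRegistry()
registry.register(
    DatasetConfig(
        path="bitmind/celeb-a-hq",
        type=Modality.IMAGE,
        media_type=MediaType.REAL,
        tags=["faces", "high-quality"],
    )
)

# Enabled real-image datasets tagged "faces", excluding manipulated ones.
matches = registry.get_datasets(
    modality=Modality.IMAGE,
    media_type=MediaType.REAL,
    tags=["faces"],
    exclude_tags=["manipulated"],
)
print([d.path for d in matches])
```
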
diff --git a/bitmind/cache/datasets/datasets.py b/bitmind/cache/datasets/datasets.py
new file mode 100644
index 00000000..31172b49
--- /dev/null
+++ b/bitmind/cache/datasets/datasets.py
@@ -0,0 +1,165 @@
+"""
+Dataset definitions for the validator cache system
+"""
+
+from typing import List
+
+from bitmind.types import Modality, MediaType, DatasetConfig
+
+
+def get_image_datasets() -> List[DatasetConfig]:
+ """
+ Get the list of image datasets used by the validator.
+
+ Returns:
+ List of image dataset configurations
+ """
+ return [
+ # Real image datasets
+ DatasetConfig(
+ path="bitmind/bm-eidon-image",
+ type=Modality.IMAGE,
+ media_type=MediaType.REAL,
+ tags=["frontier"],
+ ),
+ DatasetConfig(
+ path="bitmind/bm-real",
+ type=Modality.IMAGE,
+ media_type=MediaType.REAL,
+ ),
+ DatasetConfig(
+ path="bitmind/open-image-v7-256",
+ type=Modality.IMAGE,
+ media_type=MediaType.REAL,
+ tags=["diverse"],
+ ),
+ DatasetConfig(
+ path="bitmind/celeb-a-hq",
+ type=Modality.IMAGE,
+ media_type=MediaType.REAL,
+ tags=["faces", "high-quality"],
+ ),
+ DatasetConfig(
+ path="bitmind/ffhq-256",
+ type=Modality.IMAGE,
+ media_type=MediaType.REAL,
+ tags=["faces", "high-quality"],
+ ),
+ DatasetConfig(
+ path="bitmind/MS-COCO-unique-256",
+ type=Modality.IMAGE,
+ media_type=MediaType.REAL,
+ tags=["diverse"],
+ ),
+ DatasetConfig(
+ path="bitmind/AFHQ",
+ type=Modality.IMAGE,
+ media_type=MediaType.REAL,
+ tags=["animals", "high-quality"],
+ ),
+ DatasetConfig(
+ path="bitmind/lfw",
+ type=Modality.IMAGE,
+ media_type=MediaType.REAL,
+ tags=["faces"],
+ ),
+ DatasetConfig(
+ path="bitmind/caltech-256",
+ type=Modality.IMAGE,
+ media_type=MediaType.REAL,
+ tags=["objects", "categorized"],
+ ),
+ DatasetConfig(
+ path="bitmind/caltech-101",
+ type=Modality.IMAGE,
+ media_type=MediaType.REAL,
+ tags=["objects", "categorized"],
+ ),
+ DatasetConfig(
+ path="bitmind/dtd",
+ type=Modality.IMAGE,
+ media_type=MediaType.REAL,
+ tags=["textures"],
+ ),
+ DatasetConfig(
+ path="bitmind/idoc-mugshots-images",
+ type=Modality.IMAGE,
+ media_type=MediaType.REAL,
+ tags=["faces"],
+ ),
+ # Synthetic image datasets
+ DatasetConfig(
+ path="bitmind/JourneyDB",
+ type=Modality.IMAGE,
+ media_type=MediaType.SYNTHETIC,
+ tags=["midjourney"],
+ ),
+ DatasetConfig(
+ path="bitmind/GenImage_MidJourney",
+ type=Modality.IMAGE,
+ media_type=MediaType.SYNTHETIC,
+ tags=["midjourney"],
+ ),
+ # Semisynthetic image datasets
+ DatasetConfig(
+ path="bitmind/face-swap",
+ type=Modality.IMAGE,
+ media_type=MediaType.SEMISYNTHETIC,
+ tags=["faces", "manipulated"],
+ ),
+ ]
+
+
+def get_video_datasets() -> List[DatasetConfig]:
+ """
+ Get the list of video datasets used by the validator.
+ """
+ return [
+ # Real video datasets
+ DatasetConfig(
+ path="bitmind/bm-eidon-video",
+ type=Modality.VIDEO,
+ media_type=MediaType.REAL,
+ tags=["frontier"],
+ compressed_format="zip",
+ ),
+ DatasetConfig(
+ path="shangxd/imagenet-vidvrd",
+ type=Modality.VIDEO,
+ media_type=MediaType.REAL,
+ tags=["diverse"],
+ compressed_format="zip",
+ ),
+ DatasetConfig(
+ path="nkp37/OpenVid-1M",
+ type=Modality.VIDEO,
+ media_type=MediaType.REAL,
+ tags=["diverse", "large-zips"],
+ compressed_format="zip",
+ ),
+ # Semisynthetic video datasets
+ DatasetConfig(
+ path="bitmind/semisynthetic-video",
+ type=Modality.VIDEO,
+ media_type=MediaType.SEMISYNTHETIC,
+ tags=["faces"],
+ compressed_format="zip",
+ ),
+ ]
+
+
+def initialize_dataset_registry():
+ """
+ Initialize and populate the dataset registry.
+
+ Returns:
+ Fully populated DatasetRegistry instance
+ """
+ from bitmind.cache.datasets.dataset_registry import DatasetRegistry
+
+ registry = DatasetRegistry()
+
+ registry.register_all(get_image_datasets())
+ registry.register_all(get_video_datasets())
+
+ return registry
diff --git a/bitmind/cache/sampler/__init__.py b/bitmind/cache/sampler/__init__.py
new file mode 100644
index 00000000..70e88c4e
--- /dev/null
+++ b/bitmind/cache/sampler/__init__.py
@@ -0,0 +1,4 @@
+from .base import BaseSampler
+from .image_sampler import ImageSampler
+from .video_sampler import VideoSampler
+from .sampler_registry import SamplerRegistry
diff --git a/bitmind/cache/sampler/base.py b/bitmind/cache/sampler/base.py
new file mode 100644
index 00000000..c168df76
--- /dev/null
+++ b/bitmind/cache/sampler/base.py
@@ -0,0 +1,47 @@
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any, Dict, List
+
+
+from bitmind.cache.cache_fs import CacheFS
+from bitmind.types import CacheConfig
+
+
+class BaseSampler(ABC):
+ """
+ Base class for samplers that provide access to cached media.
+ """
+
+ def __init__(self, cache_config: CacheConfig):
+ self.cache_fs = CacheFS(cache_config)
+
+ @property
+ @abstractmethod
+ def media_file_extensions(self) -> List[str]:
+ """List of file extensions supported by this sampler"""
+ pass
+
+ @abstractmethod
+ async def sample(self, count: int) -> Dict[str, Any]:
+ """
+ Sample items from the media cache.
+
+ Args:
+ count: Number of items to sample
+
+ Returns:
+ Dictionary with sampled items information
+ """
+ pass
+
+ def get_available_files(self, use_index=True) -> List[Path]:
+ """Get list of available files in the media cache"""
+ return self.cache_fs.get_files(
+ cache_type="media",
+ file_extensions=self.media_file_extensions,
+ use_index=use_index,
+ )
+
+ def get_available_count(self, use_index=True) -> int:
+ """Get count of available files in the media cache"""
+ return len(self.get_available_files(use_index))
diff --git a/bitmind/cache/sampler/image_sampler.py b/bitmind/cache/sampler/image_sampler.py
new file mode 100644
index 00000000..587f9217
--- /dev/null
+++ b/bitmind/cache/sampler/image_sampler.py
@@ -0,0 +1,137 @@
+import json
+import random
+from pathlib import Path
+from typing import Dict, List, Any
+
+import cv2
+import numpy as np
+
+from bitmind.cache.sampler.base import BaseSampler
+from bitmind.cache.cache_fs import CacheConfig
+
+
+class ImageSampler(BaseSampler):
+ """
+ Sampler for cached image data.
+
+ This class provides access to images in the media cache,
+ allowing sampling with or without metadata.
+ """
+
+ def __init__(self, cache_config: CacheConfig):
+ super().__init__(cache_config)
+
+ @property
+ def media_file_extensions(self) -> List[str]:
+ """List of file extensions supported by this sampler"""
+ return [".jpg", ".jpeg", ".png", ".webp"]
+
+ async def sample(
+ self,
+ count: int = 1,
+ remove_from_cache: bool = False,
+ as_float32: bool = False,
+ channels_first: bool = False,
+ as_rgb: bool = True,
+ ) -> Dict[str, Any]:
+ """
+ Sample random images and their metadata from the cache.
+
+        Args:
+            count: Number of images to sample
+            remove_from_cache: Whether to remove sampled images from cache
+            as_float32: Whether to return pixel values as float32 in [0, 1] instead of uint8
+            channels_first: Whether to return arrays as (C, H, W) instead of (H, W, C)
+            as_rgb: Whether to return images in RGB channel order (BGR if False)
+
+        Returns:
+            Dictionary containing:
+            - count: Number of images successfully sampled
+            - items: List of dictionaries containing:
+                - image: Image as numpy array, RGB with shape (H, W, C) by default
+                - path: Path to the image file
+                - metadata_path: Path to the sidecar JSON metadata file
+                - metadata: Additional metadata
+                - source / index: Source parquet file and original row index, when available
+ """
+ cached_files = self.cache_fs.get_files(
+ cache_type="media",
+ file_extensions=self.media_file_extensions,
+ group_by_source=True,
+ )
+
+ if not cached_files:
+ self.cache_fs._log_warning("No images available in cache")
+ return {"count": 0, "items": []}
+
+ sampled_items = []
+
+ attempts = 0
+ max_attempts = count * 3
+
+ while len(sampled_items) < count and attempts < max_attempts:
+ attempts += 1
+
+ source = random.choice(list(cached_files.keys()))
+ if not cached_files[source]:
+ del cached_files[source]
+ if not cached_files:
+ break
+ continue
+
+ image_path = random.choice(cached_files[source])
+
+ try:
+ # Read image directly as numpy array using cv2
+ image = cv2.imread(str(image_path))
+ if image is None:
+ raise ValueError(f"Failed to load image {image_path}")
+
+ if as_float32: # else np.uint8
+ image = image.astype(np.float32) / 255.0
+
+ if as_rgb: # else bgr
+ image = image[:, :, [2, 1, 0]]
+
+ if channels_first: # else channels last
+ image = np.transpose(image, (2, 0, 1))
+
+ metadata = {}
+ metadata_path = image_path.with_suffix(".json")
+ if metadata_path.exists():
+ try:
+ with open(metadata_path, "r") as f:
+ metadata = json.load(f)
+ except Exception as e:
+ self.cache_fs._log_warning(
+ f"Error loading metadata for {image_path}: {e}"
+ )
+
+ item = {
+ "image": image,
+ "path": str(image_path),
+ "metadata_path": str(metadata_path),
+ "metadata": metadata,
+ }
+
+ if "source_parquet" in metadata:
+ item["source"] = metadata["source_parquet"]
+
+ if "original_index" in metadata:
+ item["index"] = metadata["original_index"]
+
+ sampled_items.append(item)
+
+ if remove_from_cache:
+ try:
+ image_path.unlink(missing_ok=True)
+ metadata_path.unlink(missing_ok=True)
+ cached_files[source].remove(image_path)
+ except Exception as e:
+ self.cache_fs._log_warning(
+ f"Failed to remove {image_path}: {e}"
+ )
+
+ except Exception as e:
+ self.cache_fs._log_warning(f"Failed to load image {image_path}: {e}")
+ cached_files[source].remove(image_path)
+ continue
+
+ return {"count": len(sampled_items), "items": sampled_items}
diff --git a/bitmind/cache/sampler/sampler_registry.py b/bitmind/cache/sampler/sampler_registry.py
new file mode 100644
index 00000000..5bdcdc87
--- /dev/null
+++ b/bitmind/cache/sampler/sampler_registry.py
@@ -0,0 +1,62 @@
+from typing import Dict, Optional, Any
+
+import bittensor as bt
+
+from .base import BaseSampler
+
+
+class SamplerRegistry:
+ """
+ Registry for cache samplers.
+ """
+
+ def __init__(self):
+ self._samplers: Dict[str, BaseSampler] = {}
+
+ def register(self, name: str, sampler: BaseSampler) -> None:
+ if name in self._samplers:
+ bt.logging.warning(f"Sampler {name} already registered, will be replaced")
+ self._samplers[name] = sampler
+
+ def get(self, name: str) -> Optional[BaseSampler]:
+ return self._samplers.get(name)
+
+ def get_all(self) -> Dict[str, BaseSampler]:
+ return dict(self._samplers)
+
+ def deregister(self, name: str) -> None:
+ if name in self._samplers:
+ del self._samplers[name]
+
+ async def sample(self, name: str, count: int, **kwargs) -> Optional[Dict[str, Any]]:
+ """
+ Sample from a specific sampler.
+
+ Args:
+ name: Name of the sampler to use
+ count: Number of items to sample
+
+ Returns:
+ The sampled items or None if sampler not found
+ """
+ sampler = self.get(name)
+ if not sampler:
+ bt.logging.error(f"Sampler {name} not found")
+ return None
+
+ return await sampler.sample(count, **kwargs)
+
+ async def sample_all(self, count_per_sampler: int = 1) -> Dict[str, Dict[str, Any]]:
+ """
+ Sample from all samplers.
+
+ Args:
+ count_per_sampler: Number of items to sample from each sampler
+
+ Returns:
+ Dictionary mapping sampler names to their samples
+ """
+ results = {}
+ for name, sampler in self._samplers.items():
+ results[name] = await sampler.sample(count_per_sampler)
+ return results
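+
+
+# Illustrative usage (a sketch; the CacheConfig instances are assumed):
+#   registry = SamplerRegistry()
+#   registry.register("image", ImageSampler(image_cache_config))
+#   registry.register("video", VideoSampler(video_cache_config))
+#   samples = await registry.sample_all(count_per_sampler=2)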
diff --git a/bitmind/cache/sampler/video_sampler.py b/bitmind/cache/sampler/video_sampler.py
new file mode 100644
index 00000000..90486af7
--- /dev/null
+++ b/bitmind/cache/sampler/video_sampler.py
@@ -0,0 +1,274 @@
+import json
+import math
+import random
+
+from pathlib import Path
+from typing import Dict, List, Any, Optional
+from io import BytesIO
+
+import ffmpeg
+import numpy as np
+from PIL import Image
+
+from bitmind.cache.sampler.base import BaseSampler
+from bitmind.cache.cache_fs import CacheConfig
+from bitmind.cache.util.video import get_video_metadata
+
+
+class VideoSampler(BaseSampler):
+ """
+ Sampler for cached video data.
+
+ This class provides access to videos in the media cache,
+ allowing sampling of video segments as binary data.
+ """
+
+ def __init__(self, cache_config: CacheConfig):
+ super().__init__(cache_config)
+
+ @property
+ def media_file_extensions(self) -> List[str]:
+ """List of file extensions supported by this sampler"""
+ return [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".webm"]
+
+ async def sample(
+ self,
+ count: int = 1,
+ remove_from_cache: bool = False,
+ min_duration: float = 1.0,
+ max_duration: float = 6.0,
+ ) -> Dict[str, Any]:
+ """
+ Sample random video segments from the cache.
+
+ Args:
+            count: Number of videos to sample
+            remove_from_cache: Whether to remove sampled videos from cache
+            min_duration: Minimum segment duration to extract, in seconds
+            max_duration: Maximum segment duration to extract, in seconds
+
+ Returns:
+ Dictionary containing:
+ - count: Number of videos successfully sampled
+ - items: List of dictionaries containing video binary data and metadata
+ """
+ cached_files = self.cache_fs.get_files(
+ cache_type="media",
+ file_extensions=self.media_file_extensions,
+ group_by_source=True,
+ )
+
+ if not cached_files:
+ self.cache_fs._log_warning("No videos available in cache")
+ return {"count": 0, "items": []}
+
+ sampled_items = []
+ for _ in range(count):
+ if not cached_files:
+ break
+
+ video_result = await self._sample_frames(
+ files=cached_files,
+ min_duration=min_duration,
+ max_duration=max_duration,
+ remove_from_cache=remove_from_cache,
+ )
+
+ if video_result:
+ sampled_items.append(video_result)
+
+ return {"count": len(sampled_items), "items": sampled_items}
+
+ async def _sample_frames(
+ self,
+ files,
+ min_duration: float = 1.0,
+ max_duration: float = 6.0,
+ max_fps: float = 30.0,
+ remove_from_cache: bool = False,
+ as_float32: bool = False,
+ channels_first: bool = False,
+ as_rgb: bool = True,
+ ) -> Optional[Dict[str, Any]]:
+ """
+ Sample a random video segment and return it as a numpy array.
+
+        Args:
+            files: Mapping of source name to list of candidate video paths
+            min_duration: Minimum duration of video segment to extract in seconds
+            max_duration: Maximum duration of video segment to extract in seconds
+            max_fps: Upper bound applied when a video reports an unreasonable frame rate
+            remove_from_cache: Whether to remove the source video from cache
+            as_float32: Whether to return frames as float32 (0-1) instead of uint8 (0-255)
+            channels_first: Whether to return frames with channels first (TCHW) instead of channels last (THWC)
+            as_rgb: Whether to return frames in RGB format (True) or BGR format (False)
+
+        Returns:
+            Dictionary containing:
+            - video: Video frames as numpy array with shape (T, H, W, C) by default
+            - path / metadata_path: Paths to the video file and its JSON metadata
+            - metadata: Video metadata
+            - segment: Information about the extracted segment (start time, duration, fps, dimensions, frame count)
+            Or None if sampling fails
+ """
+ for _ in range(5):
+ if not files:
+ self.cache_fs._log_warning("No more videos available to try")
+ return None
+
+ source = random.choice(list(files.keys()))
+ if not files[source]:
+ del files[source]
+ continue
+
+ video_path = random.choice(files[source])
+
+ try:
+ if not video_path.exists():
+ files[source].remove(video_path)
+ continue
+
+ try:
+ video_info = get_video_metadata(str(video_path))
+ total_duration = video_info.get("duration", 0)
+ width = int(video_info.get("width", 256))
+ height = int(video_info.get("height", 256))
+ reported_fps = float(video_info.get("fps", max_fps))
+ except Exception as e:
+ self.cache_fs._log_error(
+ f"Unable to extract video metadata from {str(video_path)}: {e}"
+ )
+ files[source].remove(video_path)
+ continue
+
+ if (
+ reported_fps > max_fps
+ or reported_fps <= 0
+ or not math.isfinite(reported_fps)
+ ):
+ self.cache_fs._log_warning(
+ f"Unreasonable FPS ({reported_fps}) detected in {video_path}, capping at {max_fps}"
+ )
+ frame_rate = max_fps
+ else:
+ frame_rate = reported_fps
+
+ target_duration = random.uniform(min_duration, max_duration)
+ target_duration = min(target_duration, total_duration)
+
+ num_frames = int(target_duration * frame_rate) + 1
+
+ actual_duration = (num_frames - 1) / frame_rate
+
+ max_start = max(0, total_duration - actual_duration)
+ start_time = random.uniform(0, max_start)
+
+ frames = []
+ no_data = []
+
+ for i in range(num_frames):
+ timestamp = start_time + (i / frame_rate)
+
+ try:
+ out_bytes, err = (
+ ffmpeg.input(str(video_path), ss=str(timestamp))
+ .filter("select", "eq(n,0)")
+ .output(
+ "pipe:",
+ vframes=1,
+ format="image2",
+ vcodec="png",
+ loglevel="error",
+ )
+ .run(capture_stdout=True, capture_stderr=True)
+ )
+
+ if not out_bytes:
+ no_data.append(timestamp)
+ continue
+
+ try:
+ frame = Image.open(BytesIO(out_bytes))
+ frame.load() # Verify image can be loaded
+ frames.append(np.array(frame))
+ except Exception as e:
+ self.cache_fs._log_error(
+ f"Failed to process frame at {timestamp}s: {e}"
+ )
+ continue
+
+ except ffmpeg.Error as e:
+ self.cache_fs._log_error(
+ f"FFmpeg error at {timestamp}s: {e.stderr.decode()}"
+ )
+ continue
+
+ if len(no_data) > 0:
+ tmin, tmax = min(no_data), max(no_data)
+ self.cache_fs._log_warning(
+ f"No data received for {len(no_data)} frames between {tmin} and {tmax}"
+ )
+
+ if not frames:
+ self.cache_fs._log_warning(
+ f"No frames successfully extracted from {video_path}"
+ )
+ files[source].remove(video_path)
+ continue
+
+ frames = np.stack(frames, axis=0)
+
+ if as_float32:
+ frames = frames.astype(np.float32) / 255.0
+
+ if not as_rgb:
+ frames = frames[:, :, :, [2, 1, 0]] # RGB to BGR
+
+ if channels_first:
+ frames = np.transpose(frames, (0, 3, 1, 2))
+
+ metadata = {}
+ metadata_path = video_path.with_suffix(".json")
+ if metadata_path.exists():
+ try:
+ with open(metadata_path, "r") as f:
+ metadata = json.load(f)
+ except Exception as e:
+ self.cache_fs._log_warning(
+ f"Error loading metadata for {video_path}: {e}"
+ )
+
+ result = {
+ "video": frames,
+ "path": str(video_path),
+ "metadata_path": str(metadata_path),
+ "metadata": metadata,
+ "segment": {
+ "start_time": start_time,
+ "duration": actual_duration,
+ "fps": frame_rate,
+ "width": width,
+ "height": height,
+ "num_frames": len(frames),
+ },
+ }
+
+ if remove_from_cache:
+ try:
+ video_path.unlink(missing_ok=True)
+ metadata_path.unlink(missing_ok=True)
+ files[source].remove(video_path)
+ except Exception as e:
+ self.cache_fs._log_warning(
+ f"Failed to remove {video_path}: {e}"
+ )
+
+ self.cache_fs._log_info(
+ f"Successfully sampled {actual_duration}s segment from {video_path} ({len(frames)} frames)"
+ )
+ return result
+
+ except Exception as e:
+ self.cache_fs._log_error(f"Error sampling from {video_path}: {e}")
+ files[source].remove(video_path)
+
+ self.cache_fs._log_error("Failed to sample any video after multiple attempts")
+ return None
diff --git a/bitmind/cache/updater/__init__.py b/bitmind/cache/updater/__init__.py
new file mode 100644
index 00000000..09ca26be
--- /dev/null
+++ b/bitmind/cache/updater/__init__.py
@@ -0,0 +1,4 @@
+from .base import BaseUpdater
+from .image_updater import ImageUpdater
+from .video_updater import VideoUpdater
+from .updater_registry import UpdaterRegistry
diff --git a/bitmind/cache/updater/base.py b/bitmind/cache/updater/base.py
new file mode 100644
index 00000000..acc9b71b
--- /dev/null
+++ b/bitmind/cache/updater/base.py
@@ -0,0 +1,270 @@
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any, List, Optional
+
+import numpy as np
+
+from bitmind.cache.cache_fs import CacheFS
+from bitmind.cache.datasets import DatasetRegistry
+from bitmind.types import CacheUpdaterConfig, CacheConfig, CacheType
+from bitmind.cache.util.download import list_hf_files, download_files
+from bitmind.cache.util.filesystem import (
+ filter_ready_files,
+ wait_for_downloads_to_complete,
+ is_source_complete,
+)
+
+
+class BaseUpdater(ABC):
+ """
+ Base class for cache updaters that handle downloading and extracting data.
+
+ This version is designed to work with block callbacks rather than having
+ its own internal timing logic.
+ """
+
+ def __init__(
+ self,
+ cache_config: CacheConfig,
+ updater_config: CacheUpdaterConfig,
+ data_manager: DatasetRegistry,
+ ):
+ self.cache_fs = CacheFS(cache_config)
+ self.updater_config = updater_config
+ self.dataset_registry = data_manager
+ self._datasets = self._get_filtered_datasets()
+ self._recently_downloaded_files = []
+
+ def _get_filtered_datasets(
+ self,
+ modality: Optional[str] = None,
+ media_type: Optional[str] = None,
+ tags: Optional[List[str]] = None,
+ exclude_tags: Optional[List[str]] = None,
+ ) -> List[Any]:
+ """Get datasets that match the cache configuration"""
+ modality = self.cache_fs.config.modality if modality is None else modality
+ media_type = (
+ self.cache_fs.config.media_type if media_type is None else media_type
+ )
+ tags = self.cache_fs.config.tags if tags is None else tags
+
+ return self.dataset_registry.get_datasets(
+            modality=modality,
+            media_type=media_type,
+            tags=tags,
+ exclude_tags=exclude_tags,
+ )
+
+ @property
+ @abstractmethod
+ def media_file_extensions(self) -> List[str]:
+ pass
+
+ @property
+ @abstractmethod
+ def compressed_file_extension(self) -> str:
+ pass
+
+ @abstractmethod
+ async def _extract_items_from_source(
+ self, source_path: Path, count: int
+ ) -> List[Path]:
+ pass
+
+ async def initialize_cache(self) -> None:
+ """
+ This performs a one-time initialization to ensure the cache has
+ content available, particularly useful during first startup.
+ """
+ self.cache_fs._log_debug("Setting up cache")
+
+ if self.cache_fs.is_empty(CacheType.MEDIA):
+ if self.cache_fs.is_empty(CacheType.COMPRESSED):
+ self.cache_fs._log_debug("Compressed cache empty; populating")
+ await self.update_compressed_cache(
+ n_sources_per_dataset=1,
+ n_datasets=1,
+ exclude_tags=["large-zips"],
+ maybe_prune=False,
+ )
+
+ self.cache_fs._log_debug(
+ "Waiting for compressed files to finish downloading..."
+ )
+ await wait_for_downloads_to_complete(
+ self._recently_downloaded_files,
+ )
+ self._recently_downloaded_files = []
+
+ self.cache_fs._log_debug(
+ "Compressed files downloaded. Updating media cache."
+ )
+ await self.update_media_cache(maybe_prune=False)
+ else:
+ self.cache_fs._log_debug(
+ "Compressed sources available; Media cache empty; populating"
+ )
+ await self.update_media_cache()
+
+ async def update_compressed_cache(
+ self,
+ n_sources_per_dataset: Optional[int] = None,
+ n_datasets: Optional[int] = None,
+ exclude_tags: Optional[List[str]] = None,
+ maybe_prune: bool = True,
+ ) -> None:
+ """
+ Update the compressed cache by downloading new files.
+
+ Args:
+ n_sources_per_dataset: Optional override for number of sources per dataset
+ n_datasets: Optional limit on number of datasets to process
+ """
+ if n_sources_per_dataset is None:
+ n_sources_per_dataset = self.updater_config.num_sources_per_dataset
+
+ if maybe_prune:
+ await self.cache_fs.maybe_prune_cache(
+ cache_type=CacheType.COMPRESSED,
+ file_extensions=[self.compressed_file_extension],
+ )
+
+ # Reset tracking list before new downloads
+ self._recently_downloaded_files = []
+
+ datasets = self._get_filtered_datasets(exclude_tags=exclude_tags)
+ if n_datasets is not None and n_datasets > 0:
+ datasets = datasets[:n_datasets]
+ np.random.shuffle(datasets)
+
+ new_files = []
+ for dataset in datasets:
+ try:
+ filenames = self._list_remote_dataset_files(dataset.path)
+ if not filenames:
+ self.cache_fs._log_warning(f"No files found for {dataset.path}")
+ continue
+
+ remote_paths = self._get_download_urls(dataset.path, filenames)
+ to_download = self._select_files_to_download(
+ remote_paths, n_sources_per_dataset
+ )
+
+ output_dir = self.cache_fs.compressed_dir / dataset.path.split("/")[-1]
+
+ self.cache_fs._log_debug(
+ f"Downloading {len(to_download)} files from {dataset.path}"
+ )
+ batch_files = await self._download_files(to_download, output_dir)
+
+ # Track downloaded files
+ self._recently_downloaded_files.extend(batch_files)
+ new_files.extend(batch_files)
+ except Exception as e:
+ self.cache_fs._log_error(f"Error downloading from {dataset.path}: {e}")
+
+ if new_files:
+ self.cache_fs._log_debug(f"Added {len(new_files)} new compressed files")
+ else:
+ self.cache_fs._log_warning(f"No new files were added to compressed cache")
+
+ async def update_media_cache(
+ self, n_items_per_source: Optional[int] = None, maybe_prune: bool = True
+ ) -> None:
+ """
+ Update the media cache by extracting from compressed sources.
+
+ Args:
+ n_items_per_source: Optional override for number of items per source
+ """
+ if n_items_per_source is None:
+ n_items_per_source = self.updater_config.num_items_per_source
+
+ if maybe_prune:
+ await self.cache_fs.maybe_prune_cache(
+ cache_type=CacheType.MEDIA, file_extensions=self.media_file_extensions
+ )
+
+ all_compressed_files = self.cache_fs.get_files(
+ cache_type=CacheType.COMPRESSED,
+ file_extensions=[self.compressed_file_extension],
+ use_index=False,
+ )
+
+ if not all_compressed_files:
+ self.cache_fs._log_warning(f"No compressed sources available")
+ return
+
+ compressed_files = filter_ready_files(all_compressed_files)
+
+ if not compressed_files:
+ self.cache_fs._log_warning(
+ f"No ready compressed sources available. Files may still be downloading."
+ )
+ return
+
+ valid_compressed_files = []
+ for path in compressed_files:
+ if not is_source_complete(path):
+ try:
+ Path(path).unlink()
+ except Exception as del_err:
+ self.cache_fs._log_error(
+ f"Failed to delete corrupted file {path}: {del_err}"
+ )
+ else:
+ valid_compressed_files.append(path)
+
+ if len(valid_compressed_files) > 10:
+ valid_compressed_files = np.random.choice(
+ valid_compressed_files, size=10, replace=False
+ ).tolist()
+
+ new_files = []
+ for source in valid_compressed_files:
+ try:
+ items = await self._extract_items_from_source(
+ source, n_items_per_source
+ )
+ new_files.extend(items)
+ except Exception as e:
+ self.cache_fs._log_error(f"Error extracting from {source}: {e}")
+
+ if new_files:
+ self.cache_fs._log_debug(f"Added {len(new_files)} new items to media cache")
+ else:
+ self.cache_fs._log_warning(f"No new items were added to media cache")
+
+ def num_media_files(self) -> int:
+ count = self.cache_fs.num_files(CacheType.MEDIA, self.media_file_extensions)
+ return count == 0
+
+ def num_compressed_files(self) -> int:
+ count = self.cache_fs.num_files(
+ CacheType.COMPRESSED, [self.compressed_file_extension]
+ )
+ return count == 0
+
+ def _select_files_to_download(self, urls: List[str], count: int) -> List[str]:
+ """Select random files to download"""
+ return np.random.choice(
+ urls, size=min(count, len(urls)), replace=False
+ ).tolist()
+
+ def _list_remote_dataset_files(self, dataset_path: str) -> List[str]:
+ """List available files in a dataset with the parquet extension"""
+ return list_hf_files(
+ repo_id=dataset_path, extension=self.compressed_file_extension
+ )
+
+ def _get_download_urls(self, dataset_path: str, filenames: List[str]) -> List[str]:
+ """Get Hugging Face download URLs for data files"""
+ return [
+ f"https://huggingface.co/datasets/{dataset_path}/resolve/main/{f}"
+ for f in filenames
+ ]
+
+ async def _download_files(self, urls: List[str], output_dir: Path) -> List[Path]:
+ """Download a subset of a remote dataset's compressed files"""
+ return await download_files(urls, output_dir)
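+
+
+# Illustrative lifecycle (a sketch; concrete subclasses are ImageUpdater and VideoUpdater):
+#   updater = ImageUpdater(cache_config, updater_config, dataset_registry)
+#   await updater.initialize_cache()          # one-time population on startup
+#   await updater.update_compressed_cache()   # periodically, e.g. from a block callback
+#   await updater.update_media_cache()        # periodically, on a shorter interval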
diff --git a/bitmind/cache/updater/image_updater.py b/bitmind/cache/updater/image_updater.py
new file mode 100644
index 00000000..a195eeee
--- /dev/null
+++ b/bitmind/cache/updater/image_updater.py
@@ -0,0 +1,78 @@
+from pathlib import Path
+from typing import List
+import traceback
+
+from bitmind.cache.updater import BaseUpdater
+from bitmind.cache.datasets import DatasetRegistry
+from bitmind.types import CacheUpdaterConfig, CacheConfig
+
+
+class ImageUpdater(BaseUpdater):
+ """
+ Updater for image data from parquet files.
+
+ This class handles downloading parquet files from Hugging Face datasets
+ and extracting images from them into the media cache.
+ """
+
+ def __init__(
+ self,
+ cache_config: CacheConfig,
+ updater_config: CacheUpdaterConfig,
+ data_manager: DatasetRegistry,
+ ):
+ super().__init__(
+ cache_config=cache_config,
+ updater_config=updater_config,
+ data_manager=data_manager,
+ )
+
+ @property
+ def media_file_extensions(self) -> List[str]:
+ """List of file extensions supported by this updater"""
+ return [".jpg", ".jpeg", ".png", ".webp"]
+
+ @property
+ def compressed_file_extension(self) -> str:
+ """File extension for compressed source files"""
+ return ".parquet"
+
+ async def _extract_items_from_source(
+ self, source_path: Path, count: int
+ ) -> List[Path]:
+ """
+ Extract images from a parquet file.
+
+ Args:
+ source_path: Path to the parquet file
+ count: Number of images to extract
+
+ Returns:
+ List of paths to extracted image files
+ """
+ self.cache_fs._log_trace(f"Extracting up to {count} images from {source_path}")
+
+ dataset_name = source_path.parent.name
+ if not dataset_name:
+ dataset_name = source_path.stem
+
+ dest_dir = self.cache_fs.cache_dir / dataset_name
+ dest_dir.mkdir(parents=True, exist_ok=True)
+
+ try:
+ from ..util import extract_images_from_parquet
+
+ saved_files = extract_images_from_parquet(
+ parquet_path=source_path, dest_dir=dest_dir, num_images=count
+ )
+
+ self.cache_fs._log_trace(
+ f"Extracted {len(saved_files)} images from {source_path}"
+ )
+ return [Path(f) for f in saved_files]
+
+ except Exception as e:
+ self.cache_fs._log_error(f"Error extracting images from {source_path}: {e}")
+ self.cache_fs._log_error(traceback.format_exc())
+ return []
diff --git a/bitmind/cache/updater/updater_registry.py b/bitmind/cache/updater/updater_registry.py
new file mode 100644
index 00000000..49686e0b
--- /dev/null
+++ b/bitmind/cache/updater/updater_registry.py
@@ -0,0 +1,29 @@
+from typing import Dict, Optional
+
+import bittensor as bt
+
+from bitmind.cache.updater import BaseUpdater
+
+
+class UpdaterRegistry:
+ """
+ Registry for cache updaters.
+ """
+
+ def __init__(self):
+ self._updaters: Dict[str, BaseUpdater] = {}
+
+ def register(self, name: str, updater: BaseUpdater) -> None:
+ if name in self._updaters:
+ bt.logging.warning(f"Updater {name} already registered, will be replaced")
+ self._updaters[name] = updater
+
+ def get(self, name: str) -> Optional[BaseUpdater]:
+ return self._updaters.get(name)
+
+ def get_all(self) -> Dict[str, BaseUpdater]:
+ return dict(self._updaters)
+
+ def deregister(self, name: str) -> None:
+ if name in self._updaters:
+ del self._updaters[name]
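+
+
+# Illustrative usage (a sketch; updater construction follows BaseUpdater above):
+#   registry = UpdaterRegistry()
+#   registry.register("image", image_updater)
+#   registry.register("video", video_updater)
+#   for name, updater in registry.get_all().items():
+#       await updater.update_media_cache()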
diff --git a/bitmind/cache/updater/video_updater.py b/bitmind/cache/updater/video_updater.py
new file mode 100644
index 00000000..91a3a9d2
--- /dev/null
+++ b/bitmind/cache/updater/video_updater.py
@@ -0,0 +1,83 @@
+from pathlib import Path
+from typing import List
+
+from bitmind.types import CacheUpdaterConfig, CacheConfig
+from bitmind.cache.updater import BaseUpdater
+from bitmind.cache.datasets import DatasetRegistry
+
+
+class VideoUpdater(BaseUpdater):
+ """
+ Updater for video data from zip files.
+
+ This class handles downloading zip files from Hugging Face datasets
+ and extracting videos from them into the media cache.
+ """
+
+ def __init__(
+ self,
+ cache_config: CacheConfig,
+ updater_config: CacheUpdaterConfig,
+ data_manager: DatasetRegistry,
+ ):
+ super().__init__(
+ cache_config=cache_config,
+ updater_config=updater_config,
+ data_manager=data_manager,
+ )
+
+ @property
+ def media_file_extensions(self) -> List[str]:
+ """List of file extensions supported by this updater"""
+ return [".mp4", ".avi", ".mov", ".mkv", ".wmv", ".webm"]
+
+ @property
+ def compressed_file_extension(self) -> str:
+ """File extension for compressed source files"""
+ return ".zip"
+
+ async def _extract_items_from_source(
+ self, source_path: Path, count: int
+ ) -> List[Path]:
+ """
+ Extract videos from a zip file.
+
+ Args:
+ source_path: Path to the zip file
+ count: Number of videos to extract
+
+ Returns:
+ List of paths to extracted video files
+ """
+ self.cache_fs._log_trace(f"Extracting up to {count} videos from {source_path}")
+
+ dataset_name = source_path.parent.name
+ if not dataset_name:
+ dataset_name = source_path.stem
+
+ dest_dir = self.cache_fs.cache_dir / dataset_name
+ dest_dir.mkdir(parents=True, exist_ok=True)
+
+ try:
+ from ..util import extract_videos_from_zip
+
+ extracted_pairs = extract_videos_from_zip(
+ zip_path=source_path,
+ dest_dir=dest_dir,
+ num_videos=count,
+ file_extensions=set(self.media_file_extensions),
+ )
+
+ # extract_videos_from_zip returns pairs of (video_path, metadata_path)
+ # We just need the video paths for our return value
+ video_paths = [Path(pair[0]) for pair in extracted_pairs]
+
+ self.cache_fs._log_trace(
+ f"Extracted {len(video_paths)} videos from {source_path}"
+ )
+ return video_paths
+
+ except Exception as e:
+ self.cache_fs._log_trace(f"Error extracting videos from {source_path}: {e}")
+ return []
diff --git a/bitmind/cache/util/__init__.py b/bitmind/cache/util/__init__.py
new file mode 100644
index 00000000..1b268345
--- /dev/null
+++ b/bitmind/cache/util/__init__.py
@@ -0,0 +1,42 @@
+from bitmind.cache.util.filesystem import (
+ is_source_complete,
+ is_zip_complete,
+ is_parquet_complete,
+ get_most_recent_update_time,
+)
+
+from bitmind.cache.util.download import (
+ download_files,
+ list_hf_files,
+ openvid1m_err_handler,
+)
+
+from bitmind.cache.util.video import (
+ get_video_duration,
+ get_video_metadata,
+ seconds_to_str,
+)
+
+from bitmind.cache.util.extract import (
+ extract_videos_from_zip,
+ extract_images_from_parquet,
+)
+
+__all__ = [
+ # Filesystem
+ "is_source_complete",
+ "is_zip_complete",
+ "is_parquet_complete",
+ "get_most_recent_update_time",
+ # Download
+ "download_files",
+ "list_hf_files",
+ "openvid1m_err_handler",
+ # Video
+ "get_video_duration",
+ "get_video_metadata",
+ "seconds_to_str",
+ # Extraction
+ "extract_videos_from_zip",
+ "extract_images_from_parquet",
+]
diff --git a/bitmind/validator/cache/download.py b/bitmind/cache/util/download.py
similarity index 55%
rename from bitmind/validator/cache/download.py
rename to bitmind/cache/util/download.py
index 6022e2b2..da5e33c4 100644
--- a/bitmind/validator/cache/download.py
+++ b/bitmind/cache/util/download.py
@@ -1,69 +1,111 @@
-import requests
import os
+import traceback
from pathlib import Path
-from requests.exceptions import RequestException
-from typing import List, Union, Dict, Optional
+from typing import List, Union, Optional
+import asyncio
+import aiohttp
import bittensor as bt
import huggingface_hub as hf_hub
+import requests
+from requests.exceptions import RequestException
-def download_files(
- urls: List[str],
- output_dir: Union[str, Path],
- chunk_size: int = 8192
-) -> List[Path]:
+def list_hf_files(repo_id, repo_type="dataset", extension=None):
+ """List files from a Hugging Face repository.
+
+ Args:
+ repo_id: Repository ID
+ repo_type: Type of repository ('dataset', 'model', etc.)
+ extension: Filter files by extension
+
+ Returns:
+ List of files in the repository
"""
- Downloads multiple files synchronously.
-
+ files = []
+ try:
+ files = list(hf_hub.list_repo_files(repo_id=repo_id, repo_type=repo_type))
+ if extension:
+ files = [f for f in files if f.endswith(extension)]
+ except Exception as e:
+ bt.logging.error(f"Failed to list files of type {extension} in {repo_id}: {e}")
+ return files
+
+
+async def download_files(
+ urls: List[str], output_dir: Union[str, Path], chunk_size: int = 8192
+) -> List[Path]:
+ """Download multiple files asynchronously.
+
Args:
urls: List of URLs to download
output_dir: Directory to save the files
chunk_size: Size of chunks to download at a time
-
+
Returns:
List of successfully downloaded file paths
"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
- downloaded_files = []
- for url in urls:
- try:
- bt.logging.info(f'Downloading {url}')
- response = requests.get(url, stream=True)
- if response.status_code != 200:
- bt.logging.error(f'Failed to download {url}: Status {response.status_code}')
- continue
+ download_tasks = []
+ timeout = aiohttp.ClientTimeout(
+ total=3600,
+ )
- filename = os.path.basename(url)
- filepath = output_dir / filename
+ async with aiohttp.ClientSession(timeout=timeout) as session:
+ # Create download tasks for each URL
+ for url in urls:
+ download_tasks.append(
+ download_single_file(session, url, output_dir, chunk_size)
+ )
- bt.logging.info(f'Writing to {filepath}')
- with open(filepath, 'wb') as f:
- for chunk in response.iter_content(chunk_size=chunk_size):
- if chunk: # filter out keep-alive chunks
- f.write(chunk)
+ # Run all downloads concurrently and gather results
+ downloaded_files = await asyncio.gather(*download_tasks, return_exceptions=True)
- downloaded_files.append(filepath)
- bt.logging.info(f'Successfully downloaded {filename}')
+ # Filter out exceptions and return only successful downloads
+ return [f for f in downloaded_files if isinstance(f, Path)]
- except Exception as e:
- bt.logging.error(f'Error downloading {url}: {str(e)}')
- continue
- return downloaded_files
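+
+
+# Illustrative usage (a sketch; the URLs and output directory are assumed, and the call
+# must run inside an event loop):
+#   paths = await download_files(urls, output_dir="/tmp/sn34/compressed")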
+async def download_single_file(
+ session: aiohttp.ClientSession, url: str, output_dir: Path, chunk_size: int
+) -> Path:
+ """Download a single file asynchronously.
+ Args:
+ session: aiohttp ClientSession to use for requests
+ url: URL to download
+ output_dir: Directory to save the file
+ chunk_size: Size of chunks to download at a time
-def list_hf_files(repo_id, repo_type='dataset', extension=None):
- files = []
+ Returns:
+ Path to the downloaded file
+ """
try:
- files = list(hf_hub.list_repo_files(repo_id=repo_id, repo_type=repo_type))
- if extension:
- files = [f for f in files if f.endswith(extension)]
+ bt.logging.info(f"Downloading {url}")
+
+ async with session.get(url) as response:
+ if response.status != 200:
+ bt.logging.error(f"Failed to download {url}: Status {response.status}")
+ raise Exception(f"HTTP error {response.status}")
+
+ filename = os.path.basename(url)
+ filepath = output_dir / filename
+
+ bt.logging.info(f"Writing to {filepath}")
+
+            # Write the response to disk in chunks; the file write itself is standard blocking I/O
+ with open(filepath, "wb") as f:
+ # Download and write in chunks
+ async for chunk in response.content.iter_chunked(chunk_size):
+ if chunk: # filter out keep-alive chunks
+ f.write(chunk)
+
+ return filepath
+
except Exception as e:
- bt.logging.error(f"Failed to list files of type {extension} in {repo_id}: {e}")
- return files
+ bt.logging.error(f"Error downloading {url}: {str(e)}")
+ bt.logging.error(traceback.format_exc())
+ raise
def openvid1m_err_handler(
@@ -71,24 +113,23 @@ def openvid1m_err_handler(
output_path: Path,
part_index: int,
chunk_size: int = 8192,
- timeout: int = 300
+ timeout: int = 300,
) -> Optional[Path]:
- """
- Synchronous error handler for OpenVid1M downloads that handles split files.
-
+ """Synchronous error handler for OpenVid1M downloads that handles split files.
+
Args:
base_zip_url: Base URL for the zip parts
output_path: Directory to save files
part_index: Index of the part to download
chunk_size: Size of download chunks
timeout: Download timeout in seconds
-
+
Returns:
Path to combined file if successful, None otherwise
"""
part_urls = [
f"{base_zip_url}{part_index}_partaa",
- f"{base_zip_url}{part_index}_partab"
+ f"{base_zip_url}{part_index}_partab",
]
error_log_path = output_path / "download_log.txt"
downloaded_parts = []
@@ -96,27 +137,27 @@ def openvid1m_err_handler(
# Download each part
for part_url in part_urls:
part_file_path = output_path / Path(part_url).name
-
+
if part_file_path.exists():
bt.logging.warning(f"File {part_file_path} exists.")
downloaded_parts.append(part_file_path)
continue
-
+
try:
response = requests.get(part_url, stream=True, timeout=timeout)
if response.status_code != 200:
raise RequestException(
f"HTTP {response.status_code}: {response.reason}"
)
-
- with open(part_file_path, 'wb') as f:
+
+ with open(part_file_path, "wb") as f:
for chunk in response.iter_content(chunk_size=chunk_size):
if chunk: # filter out keep-alive chunks
f.write(chunk)
-
+
bt.logging.info(f"File {part_url} saved to {part_file_path}")
downloaded_parts.append(part_file_path)
-
+
except Exception as e:
error_message = f"File {part_url} download failed: {str(e)}\n"
bt.logging.error(error_message)
@@ -129,23 +170,25 @@ def openvid1m_err_handler(
combined_file = output_path / f"OpenVid_part{part_index}.zip"
combined_data = bytearray()
for part_path in downloaded_parts:
- with open(part_path, 'rb') as part_file:
+ with open(part_path, "rb") as part_file:
combined_data.extend(part_file.read())
-
- with open(combined_file, 'wb') as out_file:
+
+ with open(combined_file, "wb") as out_file:
out_file.write(combined_data)
-
+
for part_path in downloaded_parts:
part_path.unlink()
-
+
bt.logging.info(f"Successfully combined parts into {combined_file}")
return combined_file
-
+
except Exception as e:
- error_message = f"Failed to combine parts for index {part_index}: {str(e)}\n"
+ error_message = (
+ f"Failed to combine parts for index {part_index}: {str(e)}\n"
+ )
bt.logging.error(error_message)
with open(error_log_path, "a") as error_log_file:
error_log_file.write(error_message)
return None
-
+
return None
diff --git a/bitmind/validator/cache/extract.py b/bitmind/cache/util/extract.py
similarity index 59%
rename from bitmind/validator/cache/extract.py
rename to bitmind/cache/util/extract.py
index 00b630c5..7b0c7e25 100644
--- a/bitmind/validator/cache/extract.py
+++ b/bitmind/cache/util/extract.py
@@ -1,33 +1,29 @@
import base64
import hashlib
import json
-import logging
-import mimetypes
import os
import random
-import warnings
import shutil
from datetime import datetime
from io import BytesIO
from pathlib import Path
-from typing import Dict, List, Optional, Set, Tuple
+from typing import List, Optional, Set, Tuple
from zipfile import ZipFile
-from PIL import Image
-import pyarrow.parquet as pq
import bittensor as bt
+import pyarrow.parquet as pq
+from PIL import Image
def extract_videos_from_zip(
zip_path: Path,
dest_dir: Path,
num_videos: int,
- file_extensions: Set[str] = {'.mp4', '.avi', '.mov', '.mkv', '.wmv'},
- include_checksums: bool = True
+ file_extensions: Set[str] = {".mp4", ".avi", ".mov", ".mkv", ".wmv"},
+ include_checksums: bool = True,
) -> List[Tuple[str, str]]:
- """
- Extract random videos and their metadata from a zip file and save them to disk.
-q
+ """Extract random videos and their metadata from a zip file and save them to disk.
+
Args:
zip_path: Path to the zip file
dest_dir: Directory to save videos and metadata
@@ -45,64 +41,63 @@ def extract_videos_from_zip(
try:
with ZipFile(zip_path) as zip_file:
video_files = [
- f for f in zip_file.namelist()
+ f
+ for f in zip_file.namelist()
if any(f.lower().endswith(ext) for ext in file_extensions)
+ and "MACOSX" not in f
]
if not video_files:
bt.logging.warning(f"No video files found in {zip_path}")
return extracted_files
- bt.logging.info(f"{len(video_files)} video files found in {zip_path}")
+ bt.logging.debug(f"{len(video_files)} video files found in {zip_path}")
selected_videos = random.sample(
- video_files,
- min(num_videos, len(video_files))
+ video_files, min(num_videos, len(video_files))
)
- bt.logging.info(f"Extracting {len(selected_videos)} randomly sampled video files from {zip_path}")
- for idx, video in enumerate(selected_videos):
- if 'MACOSX' in video:
- continue
+ bt.logging.debug(
+ f"Extracting {len(selected_videos)} randomly sampled video files from {zip_path}"
+ )
+ for video in selected_videos:
try:
# extract video and get metadata
- video_path = dest_dir / Path(video).name
+ video_path = dest_dir / Path(video).name
with zip_file.open(video) as source:
- with open(video_path, 'wb') as target:
+ with open(video_path, "wb") as target:
shutil.copyfileobj(source, target)
video_info = zip_file.getinfo(video)
metadata = {
- 'dataset': str(Path(zip_path).parent.name),
- 'source_zip': str(zip_path),
- 'path_in_zip': video,
- 'extraction_date': datetime.now().isoformat(),
- 'file_size': os.path.getsize(video_path),
- 'zip_metadata': {
- 'compress_size': video_info.compress_size,
- 'file_size': video_info.file_size,
- 'compress_type': video_info.compress_type,
- 'date_time': datetime.strftime(
- datetime(*video_info.date_time),
- '%Y-%m-%d %H:%M:%S'
+ "dataset": Path(zip_path).parent.name,
+ "source_zip": str(zip_path),
+ "path_in_zip": video,
+ "extraction_date": datetime.now().isoformat(),
+ "file_size": os.path.getsize(video_path),
+ "zip_metadata": {
+ "compress_size": video_info.compress_size,
+ "file_size": video_info.file_size,
+ "compress_type": video_info.compress_type,
+ "date_time": datetime.strftime(
+ datetime(*video_info.date_time), "%Y-%m-%d %H:%M:%S"
),
- }
+ },
}
if include_checksums:
- with open(video_path, 'rb') as f:
+ with open(video_path, "rb") as f:
file_data = f.read()
- metadata['checksums'] = {
- 'md5': hashlib.md5(file_data).hexdigest(),
- 'sha256': hashlib.sha256(file_data).hexdigest()
+ metadata["checksums"] = {
+ "md5": hashlib.md5(file_data).hexdigest(),
+ "sha256": hashlib.sha256(file_data).hexdigest(),
}
metadata_filename = f"{video_path.stem}.json"
metadata_path = dest_dir / metadata_filename
- with open(metadata_path, 'w', encoding='utf-8') as f:
+ with open(metadata_path, "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
extracted_files.append((str(video_path), str(metadata_path)))
- logging.info(f"Extracted {Path(video).name} from {zip_path}")
except Exception as e:
bt.logging.warning(f"Error extracting {video}: {e}")
@@ -115,23 +110,18 @@ def extract_videos_from_zip(
def extract_images_from_parquet(
- parquet_path: Path,
- dest_dir: Path,
- num_images: int,
- seed: Optional[int] = None
-) -> List[Tuple[str, str]]:
- """
- Extract random images and their metadata from a parquet file and save them to disk.
+ parquet_path: Path, dest_dir: Path, num_images: int, seed: Optional[int] = None
+) -> List[str]:
+ """Extract random images and their metadata from a parquet file and save them to disk.
Args:
parquet_path: Path to the parquet file
dest_dir: Directory to save images and metadata
num_images: Number of images to extract
- columns: Specific columns to include in metadata
seed: Random seed for sampling
Returns:
- List of tuples containing (image_path, metadata_path)
+ List of image file paths
"""
dest_dir = Path(dest_dir)
dest_dir.mkdir(parents=True, exist_ok=True)
@@ -140,7 +130,7 @@ def extract_images_from_parquet(
table = pq.read_table(parquet_path)
df = table.to_pandas()
sample_df = df.sample(n=min(num_images, len(df)), random_state=seed)
- image_col = next((col for col in sample_df.columns if 'image' in col.lower()), None)
+ image_col = next((col for col in sample_df.columns if "image" in col.lower()), None)
metadata_cols = [c for c in sample_df.columns if c != image_col]
saved_files = []
@@ -149,28 +139,35 @@ def extract_images_from_parquet(
try:
img_data = row[image_col]
if isinstance(img_data, dict):
- key = next((k for k in img_data if 'bytes' in k.lower() or 'image' in k.lower()), None)
+ key = next(
+ (
+ k
+ for k in img_data
+ if "bytes" in k.lower() or "image" in k.lower()
+ ),
+ None,
+ )
img_data = img_data[key]
try:
img = Image.open(BytesIO(img_data))
- except Exception as e:
+ except Exception:
img_data = base64.b64decode(img_data)
img = Image.open(BytesIO(img_data))
base_filename = f"{parquet_prefix}__image_{idx}"
- image_format = img.format.lower() if img.format else 'png'
+ image_format = img.format.lower() if img.format else "png"
img_filename = f"{base_filename}.{image_format}"
img_path = dest_dir / img_filename
img.save(img_path)
metadata = {
- 'dataset': str(Path(parquet_path).parent.name),
- 'source_parquet': str(parquet_path),
- 'original_index': str(idx),
- 'image_format': image_format,
- 'image_size': img.size,
- 'image_mode': img.mode
+ "dataset": Path(parquet_path).parent.name,
+ "source_parquet": str(parquet_path),
+ "original_index": str(idx),
+ "image_format": image_format,
+ "image_size": img.size,
+ "image_mode": img.mode,
}
for col in metadata_cols:
@@ -180,16 +177,16 @@ def extract_images_from_parquet(
metadata[col] = row[col]
except (TypeError, OverflowError):
metadata[col] = str(row[col])
-
+
metadata_filename = f"{base_filename}.json"
metadata_path = dest_dir / metadata_filename
- with open(metadata_path, 'w', encoding='utf-8') as f:
+ with open(metadata_path, "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
-
+
saved_files.append(str(img_path))
except Exception as e:
bt.logging.warning(f"Failed to extract/save image {idx}: {e}")
continue
- return saved_files
\ No newline at end of file
+ return saved_files
diff --git a/bitmind/cache/util/filesystem.py b/bitmind/cache/util/filesystem.py
new file mode 100644
index 00000000..6521d84c
--- /dev/null
+++ b/bitmind/cache/util/filesystem.py
@@ -0,0 +1,334 @@
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union, Any
+import bittensor as bt
+import pyarrow.parquet as pq
+from zipfile import ZipFile, BadZipFile
+import asyncio
+import sys
+import time
+
+from bitmind.types import FileType
+
+
+def get_most_recent_update_time(directory: Path) -> float:
+ """Get the most recent modification time of any file in directory."""
+ try:
+ mtimes = [f.stat().st_mtime for f in directory.iterdir()]
+ return max(mtimes) if mtimes else 0
+ except Exception as e:
+ bt.logging.error(f"Error getting modification times: {e}")
+ return 0
+
+
+def is_source_complete(path: Union[str, Path]) -> Optional[bool]:
+    """Check the integrity of a parquet or zip file; returns None for unsupported extensions."""
+
+ path = Path(path)
+ if path.suffix.lower() == ".parquet":
+ return is_parquet_complete(path)
+ elif path.suffix.lower() == ".zip":
+ return is_zip_complete(path)
+ else:
+ return None
+
+
+def is_zip_complete(zip_path: Union[str, Path], testzip=False) -> bool:
+ try:
+ with ZipFile(zip_path) as zf:
+            if testzip:
+                bad_file = zf.testzip()
+                if bad_file is not None:
+                    bt.logging.error(
+                        f"Zip file {zip_path} has a corrupt entry: {bad_file}"
+                    )
+                    return False
+            else:
+                zf.namelist()
+            return True
+    except Exception as e:
+ bt.logging.error(f"Zip file {zip_path} is invalid: {e}")
+ return False
+
+
+def is_parquet_complete(path: Path) -> bool:
+ try:
+ with open(path, "rb") as f:
+ pq.read_metadata(f)
+ return True
+ except Exception as e:
+ bt.logging.error(f"Parquet file {path} is incomplete or corrupted: {e}")
+ return False
+
+
+def get_dir_size(
+ path: Union[str, Path], exclude_dirs: Optional[List[str]] = None
+) -> Tuple[int, int]:
+ if exclude_dirs is None:
+ exclude_dirs = []
+
+ total_size = 0
+ file_count = 0
+ path_obj = Path(path)
+
+ try:
+ for item in path_obj.iterdir():
+ if item.is_dir() and item.name in exclude_dirs:
+ continue
+ elif item.is_file():
+ try:
+ total_size += item.stat().st_size
+ file_count += 1
+ except (OSError, PermissionError):
+ pass
+ elif item.is_dir():
+ subdir_size, subdir_count = get_dir_size(item, exclude_dirs)
+ total_size += subdir_size
+ file_count += subdir_count
+ except (PermissionError, OSError) as e:
+ print(f"Error accessing {path}: {e}", file=sys.stderr)
+
+ return total_size, file_count
+
+
+def scale_size(size: float, from_unit: str = "B", to_unit: str = "GB") -> float:
+ if size == 0:
+ return 0.0
+
+ units = ["B", "KB", "MB", "GB", "TB", "PB"]
+ from_unit, to_unit = from_unit.upper(), to_unit.upper()
+ if from_unit not in units or to_unit not in units:
+ raise ValueError(f"Units must be one of: {', '.join(units)}")
+
+ from_index = units.index(from_unit)
+ to_index = units.index(to_unit)
+ scale_factor = from_index - to_index
+
+ if scale_factor > 0:
+ return size * (1024**scale_factor)
+ elif scale_factor < 0:
+ return size / (1024 ** abs(scale_factor))
+ return size
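+
+
+# Worked examples: scale_size(1536, "MB", "GB") -> 1.5 and scale_size(2, "GB", "MB") -> 2048.0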
+
+
+def format_size(
+ size: float, from_unit: str = "B", to_unit: Optional[str] = None
+) -> str:
+ if size == 0:
+ return "0 B"
+
+ units = ["B", "KB", "MB", "GB", "TB", "PB"]
+ from_unit = from_unit.upper()
+
+ if from_unit not in units:
+ raise ValueError(f"From unit must be one of: {', '.join(units)}")
+
+ if to_unit is None:
+ current_size = scale_size(size, from_unit, "B")
+ unit_index = 0
+
+ while current_size >= 1024 and unit_index < len(units) - 1:
+ current_size /= 1024
+ unit_index += 1
+
+ return f"{current_size:.2f} {units[unit_index]}"
+ else:
+ to_unit = to_unit.upper()
+ if to_unit not in units:
+ raise ValueError(f"To unit must be one of: {', '.join(units)}")
+ scaled_size = scale_size(size, from_unit, to_unit)
+ return f"{scaled_size:.2f} {to_unit}"
+
+
+def analyze_directory(
+ root_path: Union[str, Path],
+ exclude_dirs: Optional[List[str]] = None,
+ min_file_count: int = 1,
+ log_func=None,
+) -> Dict[str, Any]:
+ if exclude_dirs is None:
+ exclude_dirs = []
+
+ path_obj = Path(root_path)
+ result = {
+ "name": path_obj.name or str(path_obj),
+ "path": str(path_obj),
+ "subdirs": [],
+ "excluded_dirs": [],
+ }
+
+ size, count = get_dir_size(path_obj, exclude_dirs)
+ result["size"] = size
+ result["count"] = count
+
+ try:
+ subdirs = [d for d in path_obj.iterdir() if d.is_dir()]
+
+ for subdir in sorted(subdirs):
+ if subdir.name in exclude_dirs:
+ _, excluded_count = get_dir_size(subdir, [])
+ if excluded_count < min_file_count:
+ continue
+
+ excluded_data = analyze_directory(subdir, [], min_file_count, log_func)
+ excluded_data["excluded"] = True
+ result["excluded_dirs"].append(excluded_data)
+ else:
+ subdir_data = analyze_directory(
+ subdir, exclude_dirs, min_file_count, log_func
+ )
+ if subdir_data["count"] < min_file_count:
+ continue
+
+ result["subdirs"].append(subdir_data)
+ except (PermissionError, OSError) as e:
+ error_msg = f"Error accessing {path_obj}: {e}"
+ if log_func:
+ log_func(error_msg)
+ else:
+ print(error_msg, file=sys.stderr)
+
+ return result
+
+
+def print_directory_tree(
+ tree_data: Dict[str, Any],
+ indent: str = "",
+ is_last: bool = True,
+ prefix: str = "",
+ log_func=None,
+) -> None:
+ if (
+ tree_data["count"] == 0
+ and not tree_data["subdirs"]
+ and not tree_data["excluded_dirs"]
+ ):
+ return
+
+ if is_last:
+ branch = "└── "
+ next_indent = indent + " "
+ else:
+ branch = "├── "
+ next_indent = indent + "│ "
+
+ name = tree_data["name"]
+ count = tree_data["count"]
+ size = scale_size(tree_data["size"])
+
+ tree_line = f"{indent}{prefix}{branch}[{name}] - {count} files, {size}"
+ if log_func:
+ log_func(tree_line)
+ else:
+ print(tree_line)
+
+ num_subdirs = len(tree_data["subdirs"])
+
+ for i, subdir in enumerate(tree_data["subdirs"]):
+ is_subdir_last = (i == num_subdirs - 1) and not tree_data["excluded_dirs"]
+ print_directory_tree(subdir, next_indent, is_subdir_last, "", log_func)
+
+ for i, excluded in enumerate(tree_data["excluded_dirs"]):
+ is_excluded_last = i == len(tree_data["excluded_dirs"]) - 1
+ print_directory_tree(
+ excluded, next_indent, is_excluded_last, "(SOURCE) ", log_func
+ )
+
+
+def is_file_older_than(file_path: Union[str, Path], seconds: float = 1.0) -> bool:
+ """Check if a file's last modification time is older than specified seconds."""
+ try:
+ mtime = Path(file_path).stat().st_mtime
+ return (time.time() - mtime) >= seconds
+ except (FileNotFoundError, PermissionError):
+ return False
+
+
+def has_stable_size(file_path: Union[str, Path], wait_time: float = 0.1) -> bool:
+ """Check if a file's size is stable (not changing)."""
+ path = Path(file_path)
+ try:
+ size1 = path.stat().st_size
+ time.sleep(wait_time)
+ size2 = path.stat().st_size
+ return size1 == size2
+ except (FileNotFoundError, PermissionError):
+ return False
+
+
+def is_file_locked(file_path: Union[str, Path]) -> bool:
+ """Check if a file is locked (being written to by another process)."""
+ try:
+ with open(file_path, "rb+") as _:
+ pass
+ return False
+ except (PermissionError, OSError):
+ return True
+
+
+def is_file_ready(
+ file_path: Union[str, Path],
+ min_age_seconds: float = 1.0,
+ check_size_stability: bool = False,
+ check_file_lock: bool = True,
+ stability_wait_time: float = 0.1,
+) -> bool:
+ """
+ Determine if a file is ready for processing (not being downloaded/written to).
+
+ Args:
+ file_path: Path to the file to check
+ min_age_seconds: Minimum age in seconds since last modification
+ check_size_stability: Whether to check if file size is stable
+ check_file_lock: Whether to check if file is locked by another process
+ stability_wait_time: Time to wait when checking size stability
+
+ Returns:
+ bool: True if the file appears ready for processing
+ """
+ file_path = Path(file_path) if isinstance(file_path, str) else file_path
+
+ if not file_path.exists() or not file_path.is_file():
+ return False
+
+ if not is_file_older_than(file_path, min_age_seconds):
+ return False
+
+ if check_size_stability and not has_stable_size(file_path, stability_wait_time):
+ return False
+
+ if check_file_lock and is_file_locked(file_path):
+ return False
+
+ return True
+
+
+def filter_ready_files(
+ file_list: List[Union[str, Path]], **kwargs
+) -> List[Union[str, Path]]:
+ """
+ Filter a list of files to only include those that are ready for processing.
+
+ Args:
+ file_list: List of file paths
+ **kwargs: Additional arguments to pass to is_file_ready()
+
+ Returns:
+ list: Filtered list containing only ready files
+ """
+ return [f for f in file_list if is_file_ready(f, **kwargs)]
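+
+
+# Illustrative usage (a sketch): keep only files that have finished downloading
+#   ready = filter_ready_files(candidate_paths, min_age_seconds=2.0)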
+
+
+async def wait_for_downloads_to_complete(
+ files: List[Path], min_age_seconds: float = 2.0, timeout_seconds: int = 180
+) -> bool:
+ if not files:
+ return True
+
+ start_time = time.time()
+
+ while time.time() - start_time < timeout_seconds:
+ ready_files = filter_ready_files(
+ file_list=files, min_age_seconds=min_age_seconds
+ )
+ if len(ready_files) == len(files):
+ return True
+ # yield to event loop
+ await asyncio.sleep(5)
+
+ bt.logging.error(f"Timeout waiting for {files} after {timeout_seconds} seconds")
+ return False
diff --git a/bitmind/cache/util/video.py b/bitmind/cache/util/video.py
new file mode 100644
index 00000000..0e3230a3
--- /dev/null
+++ b/bitmind/cache/util/video.py
@@ -0,0 +1,212 @@
+import json
+import subprocess
+import math
+import ffmpeg
+from typing import Dict, Any, Optional, Tuple
+
+
+def get_video_duration(video_path: str) -> float:
+ """Get the duration of a video file in seconds.
+
+ Args:
+ video_path: Path to the video file
+
+ Returns:
+ Duration in seconds
+
+ Raises:
+ Exception: If the duration cannot be determined
+ """
+ try:
+ probe = ffmpeg.probe(video_path)
+ duration = float(probe["format"]["duration"])
+ return duration
+ except Exception as e:
+ try:
+ result = subprocess.run(
+ [
+ "ffprobe",
+ "-v",
+ "error",
+ "-show_entries",
+ "format=duration",
+ "-of",
+ "json",
+ video_path,
+ ],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ )
+ data = json.loads(result.stdout)
+ duration = float(data["format"]["duration"])
+ return duration
+ except Exception as sub_e:
+ raise Exception(f"Failed to get video duration: {e}, {sub_e}")
+
+
+def get_video_metadata(video_path: str, max_fps: float = 30.0) -> Dict[str, Any]:
+ """Get comprehensive metadata from a video file with sanity checks.
+
+ Args:
+ video_path: Path to the video file
+        max_fps: Maximum reasonable FPS value (default: 30.0)
+
+ Returns:
+ Dictionary containing metadata with sanity-checked values
+ """
+ try:
+ ffprobe_fields = (
+ "format=duration,size,bit_rate,format_name:"
+ "stream=width,height,codec_name,codec_type,"
+ "r_frame_rate,avg_frame_rate,pix_fmt,sample_rate,channels"
+ )
+ result = subprocess.run(
+ [
+ "ffprobe",
+ "-v",
+ "error",
+ "-show_entries",
+ ffprobe_fields,
+ "-of",
+ "json",
+ video_path,
+ ],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ check=True, # This will raise CalledProcessError if ffprobe fails
+ )
+
+ data = json.loads(result.stdout)
+
+ # Extract basic format information
+ format_info = data.get("format", {})
+ streams = data.get("streams", [])
+
+ # Find video and audio streams
+ video_stream = next(
+ (s for s in streams if s.get("codec_type") == "video"), None
+ )
+ audio_stream = next(
+ (s for s in streams if s.get("codec_type") == "audio"), None
+ )
+
+ # Build base metadata
+ metadata = {
+ "duration": float(format_info.get("duration", 0)),
+ "size_bytes": int(format_info.get("size", 0)),
+ "bit_rate": (
+ int(format_info.get("bit_rate", 0))
+ if "bit_rate" in format_info
+ else None
+ ),
+ "format": format_info.get("format_name"),
+ "has_video": video_stream is not None,
+ "has_audio": audio_stream is not None,
+ }
+
+ # Add video stream details if present
+ if video_stream:
+ fps, fps_corrected, original_fps = _get_sanitized_fps(video_stream, max_fps)
+
+ metadata.update(
+ {
+ "fps": fps,
+ "width": int(video_stream.get("width", 0)),
+ "height": int(video_stream.get("height", 0)),
+ "codec": video_stream.get("codec_name"),
+ "pix_fmt": video_stream.get("pix_fmt"),
+ }
+ )
+
+ if fps_corrected:
+ metadata["original_fps"] = original_fps
+ metadata["fps_corrected"] = True
+
+ # Add audio stream details if present
+ if audio_stream:
+ metadata.update(
+ {
+ "audio_codec": audio_stream.get("codec_name"),
+ "sample_rate": audio_stream.get("sample_rate"),
+ "channels": int(audio_stream.get("channels", 0)),
+ }
+ )
+
+ return metadata
+
+ except subprocess.CalledProcessError as e:
+ return _create_error_metadata(f"ffprobe process failed: {e.stderr.strip()}")
+ except json.JSONDecodeError:
+ return _create_error_metadata("Failed to parse ffprobe output as JSON")
+ except Exception as e:
+ return _create_error_metadata(f"Unexpected error: {str(e)}")
+
+
+def _get_sanitized_fps(
+ video_stream: Dict[str, Any], max_fps: float = 60.0
+) -> Tuple[float, bool, Optional[float]]:
+ """Parse and sanitize frame rate from video stream information.
+
+ Returns:
+ Tuple of (sanitized_fps, was_corrected, original_fps_if_corrected)
+ """
+ original_fps = None
+ fps_corrected = False
+
+ # Try r_frame_rate first (usually more accurate)
+ fps = _parse_frame_rate_string(video_stream.get("r_frame_rate"))
+
+ # Fall back to avg_frame_rate if needed
+ if fps is None:
+ fps = _parse_frame_rate_string(video_stream.get("avg_frame_rate"))
+
+ # Save original before correction
+ if fps is not None:
+ original_fps = fps
+
+ # Sanity check and correct if needed
+ if fps is None or not (0 < fps <= max_fps) or not math.isfinite(fps):
+ fps_corrected = True
+ fps = 30.0 # Default to a standard frame rate
+
+ return fps, fps_corrected, original_fps if fps_corrected else None
+
+
+def _parse_frame_rate_string(frame_rate_str: Optional[str]) -> Optional[float]:
+ """Safely parse a frame rate string in format 'num/den'."""
+ if not frame_rate_str:
+ return None
+
+ try:
+ if "/" in frame_rate_str:
+ num, den = frame_rate_str.split("/")
+ num, den = float(num), float(den)
+ if den <= 0: # Avoid division by zero
+ return None
+ return num / den
+ else:
+ # Handle case where frame rate is just a number
+ return float(frame_rate_str)
+ except (ValueError, ZeroDivisionError):
+ return None
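+
+# Illustrative behavior (not part of the module): ffprobe reports frame rates as
+# rational strings, e.g. "30000/1001" parses to ~29.97; a zero denominator ("0/0")
+# or missing value returns None and triggers the fps fallback above.
+#   _parse_frame_rate_string("30000/1001")  # ~29.97
+#   _parse_frame_rate_string("0/0")         # None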
+
+
+def _create_error_metadata(error_message: str) -> Dict[str, Any]:
+ """Create a metadata dictionary for error cases."""
+ return {
+ "duration": 0,
+ "has_video": False,
+ "has_audio": False,
+ "error": error_message,
+ }
+
+
+def seconds_to_str(seconds):
+ """Convert seconds to formatted time string (HH:MM:SS)."""
+ seconds = int(float(seconds))
+ hours = seconds // 3600
+ minutes = (seconds % 3600) // 60
+ seconds = seconds % 60
+ return f"{hours:02}:{minutes:02}:{seconds:02}"
diff --git a/bitmind/config.py b/bitmind/config.py
new file mode 100644
index 00000000..3ce9d2f4
--- /dev/null
+++ b/bitmind/config.py
@@ -0,0 +1,375 @@
+import os
+import bittensor as bt
+
+MAINNET_UID = 34
+
+
+def validate_config_and_neuron_path(config):
+ r"""Checks/validates the config namespace object."""
+ full_path = os.path.expanduser(
+ "{}/{}/{}/netuid{}/{}".format(
+ config.logging.logging_dir,
+ config.wallet.name,
+ config.wallet.hotkey,
+ config.netuid,
+ config.neuron.name,
+ )
+ )
+ bt.logging.info(f"Logging path: {full_path}")
+ config.neuron.full_path = os.path.expanduser(full_path)
+ if not os.path.exists(config.neuron.full_path):
+ os.makedirs(config.neuron.full_path, exist_ok=True)
+ return config
+
+
+def add_args(parser):
+ """
+ Adds relevant arguments to the parser for operation.
+ """
+ parser.add_argument("--netuid", type=int, help="Subnet netuid", default=34)
+
+ parser.add_argument(
+ "--neuron.name",
+ type=str,
+ help="Neuron Name",
+ default="bitmind",
+ )
+
+ parser.add_argument(
+ "--epoch-length",
+ type=int,
+ help="The default epoch length (how often we set weights, measured in 12 second blocks).",
+ default=360,
+ )
+
+ parser.add_argument(
+ "--mock",
+ action="store_true",
+ help="Run in mock mode",
+ default=False,
+ )
+
+ parser.add_argument(
+ "--autoupdate-off",
+ action="store_false",
+ dest="autoupdate",
+ help="Disable automatic updates on latest version on Main.",
+ default=True,
+ )
+
+ parser.add_argument("--wandb.entity", type=str, default="bitmindai")
+
+ parser.add_argument("--wandb.off", action="store_true", default=False)
+
+
+def add_miner_args(parser):
+ """Add miner specific arguments to the parser."""
+
+ parser.add_argument(
+ "--no-force-validator-permit",
+ action="store_true",
+ help="If set, we will not force incoming requests to have a permit.",
+ default=False,
+ )
+
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cpu",
+ help="Device to use for detection models (cuda/cpu)",
+ )
+
+
+def add_validator_args(parser):
+ """Add validator specific arguments to the parser."""
+
+ parser.add_argument(
+ "--vpermit-tao-limit",
+ type=int,
+ help="The maximum number of TAO allowed to query a validator with a vpermit.",
+ default=20000,
+ )
+
+ parser.add_argument(
+ "--compressed-cache-update-interval",
+ type=int,
+ help="How often to download new zip/parquet files, measured in 12 second blocks",
+ default=720,
+ )
+
+ parser.add_argument(
+ "--media-cache-update-interval",
+ type=int,
+ help="How often to unpack random media files, measured in 12 second blocks",
+ default=60,
+ )
+
+ parser.add_argument(
+ "--challenge-interval",
+ type=int,
+ help="How often we set challenge miners, measured in 12 second blocks.",
+ default=5,
+ )
+
+ parser.add_argument(
+ "--wandb-restart-interval",
+ type=int,
+ help="How often we restart wandb run to avoid log truncation",
+ default=2000,
+ )
+
+ parser.add_argument(
+ "--cache.base-dir",
+ type=str,
+ default=os.path.expanduser("~/.cache/sn34"),
+ help="Base directory for cache storage",
+ )
+
+ parser.add_argument(
+ "--cache.max-compressed-gb",
+ type=float,
+ default=50.0,
+ help="Maximum size in GB for compressed cache",
+ )
+
+ parser.add_argument(
+ "--cache.max-media-gb",
+ type=float,
+ default=5.0,
+ help="Maximum size in GB for media cache",
+ )
+
+ parser.add_argument(
+ "--cache.media-files-per-source",
+ type=int,
+ default=50,
+ help="Number of media files to keep per source",
+ )
+
+ parser.add_argument(
+ "--neuron.max-state-backup-hours",
+ type=float,
+ help="The oldest backup of validator state to load in the case of a failure to load most recent",
+ default=1,
+ )
+
+ parser.add_argument(
+ "--neuron.miner-total-timeout",
+ type=float,
+ help="Total timeout for miner requests in seconds",
+ default=9.0,
+ )
+
+ parser.add_argument(
+ "--neuron.miner-connect-timeout",
+ type=float,
+ help="TCP connection timeout for miner requests in seconds",
+ default=4.0,
+ )
+
+ parser.add_argument(
+ "--neuron.miner-sock-connect-timeout",
+ type=float,
+ help="Socket connection timeout for miner requests in seconds",
+ default=3.0,
+ )
+
+ parser.add_argument(
+ "--neuron.heartbeat",
+ action="store_true",
+ help="Run validator heartbeat thread",
+ default=False,
+ )
+
+ parser.add_argument(
+ "--neuron.heartbeat-interval-seconds",
+ type=float,
+ help="Interval between heartbeat checks in seconds",
+ default=60.0,
+ )
+
+ parser.add_argument(
+ "--neuron.lock-sleep-seconds",
+ type=float,
+ help="Sleep duration when lock is held in seconds",
+ default=5.0,
+ )
+
+ parser.add_argument(
+ "--neuron.max-stuck-count",
+ type=int,
+ help="Number of consecutive heartbeats with no progress before restart",
+ default=5,
+ )
+
+ parser.add_argument(
+ "--neuron.sample-size",
+ type=int,
+ help="Number of miners to query per challenge",
+ default=50,
+ )
+
+ parser.add_argument(
+ "--scoring.moving-average-alpha",
+ type=float,
+ help="Alpha for miner score EMA",
+ default=0.05,
+ )
+
+ parser.add_argument(
+ "--scoring.image-weight",
+ type=float,
+ help="Weight for image modality scoring",
+ default=0.6,
+ )
+
+ parser.add_argument(
+ "--scoring.video-weight",
+ type=float,
+ help="Weight for video modality scoring",
+ default=0.4,
+ )
+
+ parser.add_argument(
+ "--scoring.binary-weight",
+ type=float,
+ help="Weight for binary classification scoring",
+ default=0.75,
+ )
+
+ parser.add_argument(
+ "--scoring.multiclass-weight",
+ type=float,
+ help="Weight for multiclass classification scoring",
+ default=0.25,
+ )
+
+ parser.add_argument(
+ "--challenge.image-prob",
+ type=float,
+ help="Probability of selecting image modality for challenges",
+ default=0.5,
+ )
+
+ parser.add_argument(
+ "--challenge.video-prob",
+ type=float,
+ help="Probability of selecting video modality for challenges",
+ default=0.5,
+ )
+
+ parser.add_argument(
+ "--challenge.real-prob",
+ type=float,
+ help="Probability of selecting real media for challenges",
+ default=0.5,
+ )
+
+ parser.add_argument(
+ "--challenge.synthetic-prob",
+ type=float,
+ help="Probability of selecting synthetic media for challenges",
+ default=0.3,
+ )
+
+ parser.add_argument(
+ "--challenge.semisynthetic-prob",
+ type=float,
+ help="Probability of selecting semisynthetic media for challenges",
+ default=0.2,
+ )
+
+ parser.add_argument(
+ "--challenge.multi-video-prob",
+ type=float,
+ help="Probability of stitching together two videos of the same media type",
+ default=0.2,
+ )
+
+ parser.add_argument(
+ "--challenge.min-clip-duration",
+ type=float,
+ help="Minimum video clip duration in seconds",
+ default=1.0,
+ )
+
+ parser.add_argument(
+ "--challenge.max-clip-duration",
+ type=float,
+ help="Maximum video clip duration in seconds",
+ default=6.0,
+ )
+
+
+def add_data_generator_args(parser):
+ parser.add_argument(
+ "--cache-dir",
+ type=str,
+ default=os.path.expanduser("~/.cache/sn34"),
+ help="Directory for caching data",
+ )
+
+ parser.add_argument(
+ "--batch-size", type=int, default=3, help="Batch size for generation"
+ )
+
+ parser.add_argument(
+ "--tasks",
+ nargs="+",
+ choices=["t2v", "t2i", "i2i", "i2v"],
+ default=["t2v", "t2i", "i2i", "i2v"],
+ help="List of tasks to run (t2v, t2i, i2i, i2v). Defaults to all.",
+ )
+
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cuda",
+ help="Device to use for generation (cuda/cpu)",
+ )
+
+ parser.add_argument(
+ "--wandb.num-batches-per-run",
+ type=int,
+ default=50,
+ help="Number of batches to generate before starting new W&B run (avoids log truncation)",
+ )
+
+ parser.add_argument("--wandb.process-name", type=str, default="generator")
+
+
+def add_proxy_args(parser):
+ parser.add_argument(
+ "--proxy.sample-size",
+ type=int,
+ default=50,
+ help="Number of miners to query for organics",
+ )
+
+ parser.add_argument(
+ "--proxy.client-url",
+ type=str,
+ default="https://subnet-api.bitmindlabs.ai",
+ help="URL for the proxy client authentication service",
+ )
+
+ parser.add_argument(
+ "--proxy.host",
+ type=str,
+ default="0.0.0.0",
+ help="Network interface to listen on",
+ )
+
+ parser.add_argument(
+ "--proxy.port",
+ type=int,
+ default=10913,
+ help="Port for the proxy server",
+ )
+
+ parser.add_argument(
+ "--proxy.external_port",
+ type=int,
+ default=10913,
+ help="Port for the proxy server",
+ )
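+
+# Illustrative sketch (not part of this module): one way an entrypoint might compose
+# these helpers into a bittensor config. The exact set of bt.*.add_args calls is an
+# assumption, not prescribed by this file.
+#   import argparse
+#   parser = argparse.ArgumentParser()
+#   bt.wallet.add_args(parser)
+#   bt.subtensor.add_args(parser)
+#   bt.logging.add_args(parser)
+#   add_args(parser)
+#   add_validator_args(parser)
+#   config = bt.config(parser)
+#   config = validate_config_and_neuron_path(config)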
diff --git a/bitmind/dataset_processing/dlib_tools/README.md b/bitmind/dataset_processing/dlib_tools/README.md
deleted file mode 100644
index 117a286e..00000000
--- a/bitmind/dataset_processing/dlib_tools/README.md
+++ /dev/null
@@ -1,3 +0,0 @@
-Our shape predictor for dlib was sourced from the following Hugging Face repository:
-
-https://huggingface.co/spaces/liangtian/birthdayCrown/blob/e96083d163606933a2cc74be372895f3cc5d1b96/shape_predictor_81_face_landmarks.dat
\ No newline at end of file
diff --git a/bitmind/dataset_processing/dlib_tools/shape_predictor_81_face_landmarks.dat b/bitmind/dataset_processing/dlib_tools/shape_predictor_81_face_landmarks.dat
deleted file mode 100644
index 99036323..00000000
Binary files a/bitmind/dataset_processing/dlib_tools/shape_predictor_81_face_landmarks.dat and /dev/null differ
diff --git a/bitmind/encoding.py b/bitmind/encoding.py
new file mode 100644
index 00000000..ea1e0054
--- /dev/null
+++ b/bitmind/encoding.py
@@ -0,0 +1,160 @@
+import numpy as np
+import cv2
+import ffmpeg
+import os
+import tempfile
+from typing import List
+from io import BytesIO
+from PIL import Image
+
+
+def image_to_bytes(img):
+ """Convert image array to bytes using JPEG encoding with PIL.
+ Args:
+ img (np.ndarray): Image array of shape (C, H, W) or (H, W, C)
+ Can be float32 [0,1] or uint8 [0,255]
+ Returns:
+ bytes: JPEG encoded image bytes
+ str: Content type 'image/jpeg'
+ """
+ # Convert float32 [0,1] to uint8 [0,255] if needed
+ if img.dtype == np.float32:
+ img = (img * 255).astype(np.uint8)
+ elif img.dtype != np.uint8:
+ raise ValueError(f"Image must be float32 or uint8, got {img.dtype}")
+
+ if img.shape[0] == 3 and len(img.shape) == 3: # If in CHW format
+ img = np.transpose(img, (1, 2, 0)) # CHW to HWC
+
+ # Ensure we have a 3-channel image (H,W,3)
+ if len(img.shape) == 2:
+ # Convert grayscale to RGB
+ img = np.stack([img, img, img], axis=2)
+ elif img.shape[2] == 1:
+ # Convert single channel to RGB
+ img = np.concatenate([img, img, img], axis=2)
+ elif img.shape[2] == 4:
+ # Drop alpha channel
+ img = img[:, :, :3]
+ elif img.shape[2] != 3:
+ raise ValueError(f"Expected 1, 3 or 4 channels, got {img.shape[2]}")
+
+ pil_img = Image.fromarray(img)
+ if pil_img.mode != "RGB":
+ pil_img = pil_img.convert("RGB")
+
+ buffer = BytesIO()
+ pil_img.save(buffer, format="JPEG", quality=75)
+ buffer.seek(0)
+
+ return buffer.getvalue(), "image/jpeg"
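+
+# Illustrative behavior (not part of the module): float32 [0,1] HWC and uint8 CHW
+# inputs are both accepted; the output is always JPEG-encoded RGB bytes.
+#   frame = np.random.rand(224, 224, 3).astype(np.float32)
+#   data, content_type = image_to_bytes(frame)  # content_type == 'image/jpeg'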
+
+
+def video_to_bytes(video: np.ndarray, fps: int | None = None) -> tuple[bytes, str]:
+ """
+    Convert a (T, H, W, C) or (T, C, H, W) uint8/float32 video to MP4. Each frame
+    is first passed through Pillow JPEG (adding normal JPEG artefacts), then the
+    sequence is encoded losslessly with libx264rgb.
+
+ Returns:
+ bytes: In‑memory MP4 file.
+ str: MIME‑type ("video/mp4").
+ """
+ # ------------- 0. validation / normalisation -------------------------------
+ if video.dtype == np.float32:
+ assert video.max() <= 1.0, video.max()
+ video = (video * 255).clip(0, 255).astype(np.uint8)
+ elif video.dtype != np.uint8:
+ raise ValueError(f"Unsupported dtype: {video.dtype}")
+
+ fps = fps or 30
+
+    if video.ndim != 4:
+        raise ValueError(f"Expected shape (T, H, W, C), got {video.shape}")
+
+    # TCHW → THWC
+    if video.shape[1] <= 4 and video.shape[3] > 4:
+        video = np.transpose(video, (0, 2, 3, 1))
+
+    if video.shape[3] not in (1, 3):
+        raise ValueError(f"Expected shape (T, H, W, C), got {video.shape}")
+
+ T, H, W, C = video.shape
+
+ # ------------- 1. apply Pillow JPEG to every frame -------------------------
+ jpeg_degraded_frames: List[np.ndarray] = []
+ for idx, frame in enumerate(video):
+ buf = BytesIO()
+ Image.fromarray(frame).save(
+ buf,
+ format="JPEG",
+ quality=75,
+ subsampling=2, # 0=4:4:4, 1=4:2:2, 2=4:2:0 (Pillow default = 2)
+ optimize=False,
+ progressive=False,
+ )
+ buf.seek(0)
+ # decode back to RGB so FFmpeg sees the artefact‑laden pixels
+ degraded = np.array(Image.open(buf).convert("RGB"), dtype=np.uint8)
+ if degraded.shape != (H, W, 3):
+ raise ValueError(f"Decoded shape mismatch at frame {idx}: {degraded.shape}")
+ jpeg_degraded_frames.append(degraded)
+
+ degraded_video = np.stack(jpeg_degraded_frames, axis=0) # (T,H,W,3)
+
+ # ------------- 2. write raw RGB + encode losslessly ------------------------
+ with tempfile.TemporaryDirectory() as tmpdir:
+ raw_path = os.path.join(tmpdir, "input.raw")
+ video_path = os.path.join(tmpdir, "output.mp4")
+
+ degraded_video.tofile(raw_path) # write as one big rawvideo blob
+
+ try:
+ (
+ ffmpeg.input(
+ raw_path,
+ format="rawvideo",
+ pix_fmt="rgb24",
+ s=f"{W}x{H}",
+ r=fps,
+ )
+ .output(
+ video_path,
+ vcodec="libx264rgb",
+ crf=0, # mathematically lossless
+ preset="veryfast",
+ pix_fmt="rgb24",
+ movflags="+faststart",
+ )
+ .global_args("-y", "-hide_banner", "-loglevel", "error")
+ .run()
+ )
+ except ffmpeg.Error as e:
+ raise RuntimeError(
+ f"FFmpeg encoding failed:\n{e.stderr.decode(errors='ignore')}"
+ ) from e
+
+ with open(video_path, "rb") as f:
+ video_bytes = f.read()
+
+ return video_bytes, "video/mp4"
+
+
+def media_to_bytes(media, fps=30):
+ """Convert image or video array to bytes, using PNG encoding for both.
+
+ Args:
+ media (np.ndarray): Either:
+ - Image array of shape (C, H, W)
+ - Video array of shape (T, C, H, W)
+ Can be float32 [0,1] or uint8 [0,255]
+ fps (int): Frames per second for video (default: 30)
+
+ Returns:
+ bytes: Encoded media bytes
+ str: Content type (either 'image/png' or 'video/avi')
+ """
+ if len(media.shape) == 3: # Image
+ return image_to_bytes(media)
+ elif len(media.shape) == 4: # Video
+ return video_to_bytes(media, fps)
+ else:
+ raise ValueError(
+ f"Invalid media shape: {media.shape}. Expected (C,H,W) for image or (T,C,H,W) for video."
+ )
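+
+# Illustrative behavior (not part of the module): dispatch is based on the number of
+# dimensions, so a 4-D array is treated as video and re-encoded to MP4 (requires the
+# ffmpeg binary on PATH).
+#   clip = np.random.randint(0, 255, (8, 3, 64, 64), dtype=np.uint8)  # (T, C, H, W)
+#   data, content_type = media_to_bytes(clip, fps=8)  # content_type == 'video/mp4'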
diff --git a/bitmind/epistula.py b/bitmind/epistula.py
new file mode 100644
index 00000000..7354d93d
--- /dev/null
+++ b/bitmind/epistula.py
@@ -0,0 +1,186 @@
+import json
+from hashlib import sha256
+from uuid import uuid4
+from math import ceil
+from typing import Annotated, Any, Dict, Optional
+
+import bittensor as bt
+import numpy as np
+import asyncio
+import ast
+import time
+import httpx
+import aiohttp
+from substrateinterface import Keypair
+
+from bitmind.types import Modality
+
+
+EPISTULA_VERSION = str(2)
+
+
+def generate_header(
+ hotkey: Keypair,
+ body: Any,
+ signed_for: Optional[str] = None,
+) -> Dict[str, Any]:
+ timestamp = round(time.time() * 1000)
+ timestampInterval = ceil(timestamp / 1e4) * 1e4
+ uuid = str(uuid4())
+ req_hash = None
+ if isinstance(body, bytes):
+ req_hash = sha256(body).hexdigest()
+ else:
+ req_hash = sha256(json.dumps(body).encode("utf-8")).hexdigest()
+
+ headers = {
+ "Epistula-Version": EPISTULA_VERSION,
+ "Epistula-Timestamp": str(timestamp),
+ "Epistula-Uuid": uuid,
+ "Epistula-Signed-By": hotkey.ss58_address,
+ "Epistula-Request-Signature": "0x"
+ + hotkey.sign(f"{req_hash}.{uuid}.{timestamp}.{signed_for or ''}").hex(),
+ }
+ if signed_for:
+ headers["Epistula-Signed-For"] = signed_for
+ headers["Epistula-Secret-Signature-0"] = (
+ "0x" + hotkey.sign(str(timestampInterval - 1) + "." + signed_for).hex()
+ )
+ headers["Epistula-Secret-Signature-1"] = (
+ "0x" + hotkey.sign(str(timestampInterval) + "." + signed_for).hex()
+ )
+ headers["Epistula-Secret-Signature-2"] = (
+ "0x" + hotkey.sign(str(timestampInterval + 1) + "." + signed_for).hex()
+ )
+ return headers
+
+
+def verify_signature(
+ signature, body: bytes, timestamp, uuid, signed_for, signed_by, now
+) -> Optional[Annotated[str, "Error Message"]]:
+ if not isinstance(signature, str):
+ return "Invalid Signature"
+    try:
+        timestamp = int(timestamp)
+    except (TypeError, ValueError):
+        return "Invalid Timestamp"
+ if not isinstance(signed_by, str):
+ return "Invalid Sender key"
+ if not isinstance(signed_for, str):
+ return "Invalid receiver key"
+ if not isinstance(uuid, str):
+ return "Invalid uuid"
+ if not isinstance(body, bytes):
+ return "Body is not of type bytes"
+ ALLOWED_DELTA_MS = 8000
+ keypair = Keypair(ss58_address=signed_by)
+ if timestamp + ALLOWED_DELTA_MS < now:
+ return "Request is too stale"
+ message = f"{sha256(body).hexdigest()}.{uuid}.{timestamp}.{signed_for}"
+ verified = keypair.verify(message, signature)
+ if not verified:
+ return "Signature Mismatch"
+ return None
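+
+# Illustrative round trip (not part of the module); //Alice is an arbitrary dev key,
+# not a real wallet.
+#   kp = Keypair.create_from_uri("//Alice")
+#   body = json.dumps({"hello": "world"}).encode("utf-8")
+#   headers = generate_header(kp, body, signed_for="")
+#   err = verify_signature(
+#       headers["Epistula-Request-Signature"], body,
+#       headers["Epistula-Timestamp"], headers["Epistula-Uuid"],
+#       "", kp.ss58_address, now=round(time.time() * 1000),
+#   )
+#   assert err is None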
+
+
+def create_header_hook(hotkey, axon_hotkey, model):
+ async def add_headers(request: httpx.Request):
+ for key, header in generate_header(hotkey, request.read(), axon_hotkey).items():
+ request.headers[key] = header
+
+ return add_headers
+
+
+async def query_miner(
+ uid: int,
+ media: bytes,
+ content_type: str,
+ modality: Modality,
+ axon_info: bt.AxonInfo,
+ session: aiohttp.ClientSession,
+ hotkey: bt.Keypair,
+ total_timeout: float,
+ connect_timeout: Optional[float] = None,
+ sock_connect_timeout: Optional[float] = None,
+) -> Dict[str, Any]:
+ """
+ Query a miner with media data.
+
+ Args:
+ uid: miner uid
+ media: encoded media
+ content_type: determined by media_to_bytes
+ modality: Type of media ('image' or 'video')
+ axon_info: miner AxonInfo
+ session: aiohttp client session
+ hotkey: validator hotkey Keypair for signing the request
+ total_timeout: Total timeout for the request
+ connect_timeout: Connection timeout
+ sock_connect_timeout: Socket connection timeout
+
+ Returns:
+ Dictionary containing the response
+ """
+ response = {
+ "uid": uid,
+ "hotkey": axon_info.hotkey,
+ "status": 500,
+ "prediction": None,
+ "error": "",
+ }
+
+ try:
+
+ headers = generate_header(hotkey, media, axon_info.hotkey)
+ url = f"http://{axon_info.ip}:{axon_info.port}/detect_{modality}"
+
+ async with session.post(
+ url,
+ headers={
+ "Content-Type": content_type,
+ "X-Media-Type": modality,
+ **headers,
+ },
+ data=media,
+ timeout=aiohttp.ClientTimeout(
+ total=total_timeout,
+ connect=connect_timeout,
+ sock_connect=sock_connect_timeout,
+ ),
+ ) as res:
+ response["status"] = res.status
+ if res.status != 200:
+ response["error"] = f"HTTP {res.status} error"
+ return response
+ try:
+ data = await res.json()
+ if "prediction" not in data:
+ response["error"] = "Missing prediction in response"
+ return response
+
+ pred = [float(p) for p in data["prediction"]]
+ response["prediction"] = np.array(pred)
+ return response
+
+ except json.JSONDecodeError:
+ response["error"] = "Failed to decode JSON response"
+ return response
+
+ except (TypeError, ValueError):
+ response["error"] = (
+ f"Invalid prediction value: {data.get('prediction')}"
+ )
+ return response
+
+ except asyncio.TimeoutError:
+ response["status"] = 408
+ response["error"] = "Request timed out"
+ except aiohttp.ClientConnectorError as e:
+ response["status"] = 503
+ response["error"] = f"Connection error: {str(e)}"
+ except aiohttp.ClientError as e:
+ response["status"] = 400
+ response["error"] = f"Network error: {str(e)}"
+ except Exception as e:
+ response["error"] = f"Unknown error: {str(e)}"
+
+ return response
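+
+# Illustrative sketch (not part of the module): query_miner is a coroutine, so callers
+# run it inside an aiohttp session. The wallet/axon values below are assumptions; the
+# modality argument is annotated as Modality but a plain "image" string is shown here.
+#   async def probe(wallet, axon_info, media, content_type):
+#       async with aiohttp.ClientSession() as session:
+#           return await query_miner(
+#               uid=0, media=media, content_type=content_type, modality="image",
+#               axon_info=axon_info, session=session, hotkey=wallet.hotkey,
+#               total_timeout=9.0,
+#           )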
diff --git a/bitmind/generation/__init__.py b/bitmind/generation/__init__.py
new file mode 100644
index 00000000..4e309c59
--- /dev/null
+++ b/bitmind/generation/__init__.py
@@ -0,0 +1,3 @@
+from .generation_pipeline import GenerationPipeline
+from .prompt_generator import PromptGenerator
+from .models import initialize_model_registry
diff --git a/bitmind/generation/generation_pipeline.py b/bitmind/generation/generation_pipeline.py
new file mode 100644
index 00000000..184abdb8
--- /dev/null
+++ b/bitmind/generation/generation_pipeline.py
@@ -0,0 +1,546 @@
+import gc
+import json
+import random
+import time
+import asyncio
+from pathlib import Path
+from typing import Dict, Optional, Any, Union, List
+import traceback
+
+import bittensor as bt
+import numpy as np
+import torch
+from diffusers.utils import export_to_video
+from PIL import Image
+
+from bitmind.types import CacheConfig, ModelTask
+from bitmind.generation.util.image import create_random_mask, is_black_output
+from bitmind.generation.util.prompt import truncate_prompt_if_too_long
+from bitmind.generation.prompt_generator import PromptGenerator
+from bitmind.generation.util.model import (
+ create_pipeline_generator,
+ enable_model_optimizations,
+)
+from bitmind.generation.model_registry import ModelRegistry
+from bitmind.generation.models import initialize_model_registry
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+torch.set_float32_matmul_precision("high")
+
+IMAGE_ANNOTATION_MODEL: str = "Salesforce/blip2-opt-6.7b-coco"
+TEXT_MODERATION_MODEL: str = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
+
+
+class GenerationPipeline:
+ """
+ A class for generating synthetic images and videos.
+
+    This class supports different prompt generation strategies and can utilize
+    various text-to-video (t2v), text-to-image (t2i), image-to-image (i2i),
+    and image-to-video (i2v) models.
+
+    Attributes:
+        output_dir: Directory to write generated data
+        model_registry: Registry of available models
+        device: Device identifier used for generation
+ """
+
+ def __init__(
+ self,
+ output_dir: Optional[Union[str, Path]] = None,
+ model_registry: Optional[ModelRegistry] = None,
+ device: str = "cuda",
+ ) -> None:
+ """
+        Initialize the GenerationPipeline.
+
+        Args:
+            output_dir: Directory to write generated data
+            model_registry: Optional ModelRegistry instance
+            device: Device identifier (e.g. "cuda" or "cpu")
+ """
+ self.output_dir = Path(output_dir)
+ self.model_registry = model_registry or initialize_model_registry()
+ self.device = device
+ self.loop = asyncio.get_event_loop()
+
+ self.prompt_generator = PromptGenerator(
+ vlm_name=IMAGE_ANNOTATION_MODEL, llm_name=TEXT_MODERATION_MODEL
+ )
+
+ def generate(
+ self,
+ image_samples: List[dict],
+ tasks: Optional[Union[str, List[str]]] = None,
+ model_names: Optional[Union[str, List[str]]] = None,
+ ) -> Dict[str, Any]:
+ """
+ Generate synthetic data based on input parameters.
+
+ Args:
+            image_samples: Image samples returned by ImageSampler
+            tasks: Optional task type(s) ('t2i', 't2v', 'i2i', 'i2v').
+            model_names: Optional model name(s) to generate with.
+
+        Returns:
+            List of file paths for the generated media.
+
+        Raises:
+            ValueError: If any requested model name is not in the registry.
+ """
+ bt.logging.info(f"---------- Starting Generation ----------")
+ prompts = self.generate_prompts(image_samples, downstream_tasks=tasks, clear_gpu=True)
+ paths, stats = self.generate_media(prompts, model_names, image_samples, tasks)
+
+ def log_stats(stats):
+ model_names = list(stats.keys())
+ total_successes = sum([stats[name]["success"] for name in model_names])
+
+ if total_successes == 0:
+ log_fn = bt.logging.error
+ elif total_successes == len(image_samples) * len(model_names):
+ log_fn = bt.logging.success
+ else:
+ log_fn = bt.logging.warning
+
+ log_fn(json.dumps(stats, indent=2))
+
+ log_stats(stats)
+ bt.logging.info(f"---------- Generation Complete ----------")
+ return paths
+
+ def generate_prompts(
+ self,
+ image_samples: List[dict],
+ downstream_tasks: Optional[List[str]] = None,
+ clear_gpu: bool = True,
+    ) -> dict:
+        """
+        Generate prompts for each input image, keyed by downstream task and image index.
+ """
+ if downstream_tasks is None:
+ downstream_tasks = [
+ ModelTask.TEXT_TO_IMAGE.value,
+ ModelTask.TEXT_TO_VIDEO.value,
+ ModelTask.IMAGE_TO_IMAGE.value,
+ ModelTask.IMAGE_TO_VIDEO.value,
+ ]
+
+ k = len(image_samples)
+ bt.logging.info(f"Generating {k} prompt{'s' if k > 1 else ''}")
+
+ self.prompt_generator.load_models()
+
+ # organize prompts in a dict to avoid failed prompt generations causing misaligned images/prompts
+ prompts = {task: {} for task in downstream_tasks}
+ for i in range(k):
+ image_path = image_samples[i].get('path')
+ image = image_samples[i].get('image')
+ for task in downstream_tasks:
+ try:
+ prompts[task][i] = self.prompt_generator.generate(
+ image, downstream_task=task
+ )
+
+ bt.logging.info(f"Generated prompt {i+1}/{k}: {prompts[task][i]} from {image_path}")
+ except Exception as e:
+ prompts[task][i] = None
+ bt.logging.error(f"Error generating prompt for image {i+1}: {e} from {image_path}")
+ bt.logging.error(traceback.format_exc())
+ continue
+
+ if clear_gpu:
+ self.prompt_generator.clear_gpu()
+
+ return prompts
+
+ def generate_media(
+ self,
+ prompts: Union[dict, str],
+ model_names: Optional[Union[str, List[str]]] = None,
+ image_samples: List[dict] = None,
+ tasks: Optional[Union[str, List[str]]] = None,
+    ) -> tuple:
+        """
+        Generate synthetic media from text prompts.
+
+        Args:
+            prompts: A single text prompt, a list of prompts, or a dictionary with
+                an outer key of generation task type, inner key of image index,
+                and value of the prompt
+            model_names: Optional model name(s) to use for generation.
+            image_samples: Optional input images for image-to-image and
+                image-to-video generation.
+            tasks: Optional generation task type(s) ('t2i', 't2v', 'i2i', 'i2v').
+
+        Returns:
+            Tuple of (save_paths, stats), where save_paths lists the written media
+            files and stats counts attempts and successes per model.
+ """
+ model_names = self._validate_model_names(model_names, tasks)
+
+ if isinstance(prompts, str):
+ prompts = [prompts]
+
+ n_models = len(model_names)
+
+ stats = {model_name: {"total": 0, "success": 0} for model_name in model_names}
+ save_paths = []
+
+ for model_idx, model_name in enumerate(model_names):
+ modality = self.model_registry.get_modality(model_name)
+ task = self.model_registry.get_task(model_name)
+
+ if isinstance(prompts, list):
+ task_prompts = {i: p for i, p in enumerate(prompts)}
+ else:
+ # task-specific prompts (motion enhancement for video)
+ task_prompts = prompts[task]
+
+            n_prompts = len(task_prompts)
+            for prompt_idx in task_prompts:
+ stats[model_name]["total"] += 1
+ bt.logging.info(
+ f"Starting batch | Model {model_idx+1}/{n_models} | Prompt {prompt_idx+1}/{n_prompts}"
+ )
+ bt.logging.info(f" Model: {model_name}")
+ bt.logging.info(f" Prompt: {task_prompts[prompt_idx]}")
+
+ try:
+ image = None
+ if image_samples is not None and len(image_samples) > prompt_idx:
+ image = image_samples[prompt_idx].get('image')
+
+ # 3 retries for black (NSFW filtered) output
+ for _ in range(3):
+ gen_output = self._generate_media_with_model(
+ model_name, task_prompts[prompt_idx], image
+ )
+ if is_black_output(modality, gen_output):
+ # sanitize and retry
+ self.clear_gpu()
+ self.prompt_generator.load_llm()
+ task_prompts[prompt_idx] = self.prompt_generator.sanitize(
+ task_prompts[prompt_idx]
+ )
+ self.prompt_generator.clear_gpu()
+ else:
+ break
+
+ bt.logging.info(
+ {
+ k: v
+ for k, v in gen_output.items()
+ if k not in (modality, "source_image", "mask_image")
+ }
+ )
+ save_paths.append(self._save_media_and_metadata(gen_output))
+ stats[model_name]["success"] += 1
+ except Exception as e:
+ bt.logging.error(f"Failed to either generate or save media: {e}")
+ bt.logging.error(f" Model: {model_name}")
+ bt.logging.error(f" Prompt: {task_prompts[prompt_idx]}")
+ bt.logging.error(traceback.format_exc())
+
+ return save_paths, stats
+
+ def _load_model(
+ self,
+ model_name: Optional[str] = None,
+    ) -> bool:
+ bt.logging.info(f"Loading {model_name}")
+ try:
+ model_config = self.model_registry.get_model_dict(model_name)
+ bt.logging.info(
+ json.dumps({k: str(v) for k, v in model_config.items()}, indent=2)
+ )
+
+ pipeline_cls = model_config["pipeline_cls"]
+ pipeline_args = model_config.get("from_pretrained_args", {}).copy()
+
+ # Handle custom loading functions passed as tuples
+ for k, v in pipeline_args.items():
+ if isinstance(v, tuple) and callable(v[0]):
+ pipeline_args[k] = v[0](**v[1])
+
+ model_id = pipeline_args.pop("model_id", model_name)
+
+ if isinstance(pipeline_cls, dict):
+ # Multi-stage pipeline
+ MODEL = {}
+ for stage_name, stage_cls in pipeline_cls.items():
+ stage_args = pipeline_args.get(stage_name, {})
+ base_model = stage_args.get("base", model_id)
+ stage_args_filtered = {
+ k: v for k, v in stage_args.items() if k != "base"
+ }
+
+ bt.logging.debug(f"Loading {stage_name} from {base_model}")
+ MODEL[stage_name] = stage_cls.from_pretrained(
+ base_model,
+ **stage_args_filtered,
+ add_watermarker=False,
+ )
+
+ enable_model_optimizations(
+ model=MODEL[stage_name],
+ device=self.device,
+ enable_cpu_offload=model_config.get(
+ "enable_model_cpu_offload", False
+ ),
+ enable_sequential_cpu_offload=model_config.get(
+ "enable_sequential_cpu_offload", False
+ ),
+ enable_vae_slicing=model_config.get(
+ "vae_enable_slicing", False
+ ),
+ enable_vae_tiling=model_config.get("vae_enable_tiling", False),
+ stage_name=stage_name,
+ )
+
+ MODEL[stage_name].watermarker = None
+ else:
+ # Single-stage pipeline
+ MODEL = pipeline_cls.from_pretrained(
+ model_id,
+ **pipeline_args,
+ add_watermarker=False,
+ )
+
+ # Load LoRA weights if specified
+ if "lora_model_id" in model_config:
+ bt.logging.info(
+ f"Loading LoRA weights from {model_config['lora_model_id']}"
+ )
+ lora_loading_args = model_config.get("lora_loading_args", {})
+                    MODEL.load_lora_weights(
+ model_config["lora_model_id"], **lora_loading_args
+ )
+
+ # Load scheduler if specified
+ scheduler_config = model_config.get("scheduler", {})
+ if scheduler_config:
+ sched_cls = scheduler_config["cls"]
+ sched_args = scheduler_config.get("from_config_args", {})
+ MODEL.scheduler = sched_cls.from_config(
+ MODEL.scheduler.config, **sched_args
+ )
+
+ enable_model_optimizations(
+ model=MODEL,
+ device=self.device,
+ enable_cpu_offload=model_config.get(
+ "enable_model_cpu_offload", False
+ ),
+ enable_sequential_cpu_offload=model_config.get(
+ "enable_sequential_cpu_offload", False
+ ),
+ enable_vae_slicing=model_config.get("vae_enable_slicing", False),
+ enable_vae_tiling=model_config.get("vae_enable_tiling", False),
+ )
+ MODEL.watermarker = None
+
+ self.model = MODEL
+ bt.logging.info(f"Loaded {model_name}")
+ return True
+
+ except Exception as e:
+ bt.logging.error(f"Error loading model: {model_name}")
+ bt.logging.error(traceback.format_exc())
+ return False
+
+ def _generate_media_with_model(self, model_name, prompt, image):
+ model_config = self.model_registry.get_model_dict(model_name)
+ task = self.model_registry.get_task(model_name)
+
+ if task == "i2i" and image is None:
+ raise ValueError(
+ "An image must be provided for image-to-image model {model_name}"
+ )
+
+ if not self._load_model(model_name):
+ raise RuntimeError(f"Failed to load {model_name}")
+ return {}
+
+ bt.logging.debug("Preparing generation arguments")
+ gen_args = model_config.get("generate_args", {}).copy()
+ mask_center = None
+
+        # prep task-specific generation args
+ if task == "i2i":
+ if image is None:
+ raise ValueError("image cannot be None for image-to-image generation")
+ image = Image.fromarray(image)
+ target_size = (1024, 1024)
+ if image.size[0] > target_size[0] or image.size[1] > target_size[1]:
+ image = image.resize(target_size, Image.Resampling.LANCZOS)
+
+ gen_args["mask_image"], mask_center = create_random_mask(image.size)
+ gen_args["image"] = image
+
+ elif task == "i2v":
+ if image is None:
+ raise ValueError("image cannot be None for image-to-video generation")
+ image = Image.fromarray(image)
+ # Get target size from gen_args if specified, otherwise use default
+ target_size = (gen_args.get("height", 768), gen_args.get("width", 768))
+ if image.size[0] > target_size[0] or image.size[1] > target_size[1]:
+ image = image.resize(target_size, Image.Resampling.LANCZOS)
+ gen_args["image"] = image
+
+ # Prepare generation arguments
+ for k, v in gen_args.items():
+ if isinstance(v, dict):
+ if "min" in v and "max" in v:
+ # For i2v, use minimum values to save memory
+ if task == "i2v":
+ gen_args[k] = v["min"]
+ else:
+ gen_args[k] = np.random.randint(v["min"], v["max"])
+
+ if "options" in v:
+ gen_args[k] = random.choice(v["options"])
+
+ if "resolution" in gen_args:
+ gen_args["height"] = gen_args["resolution"][0]
+ gen_args["width"] = gen_args["resolution"][1]
+ del gen_args["resolution"]
+
+ truncated_prompt = truncate_prompt_if_too_long(prompt, self.model)
+ bt.logging.debug(f"Generating media from prompt: {truncated_prompt}")
+ bt.logging.debug(f"Generation args: {gen_args}")
+
+ generate_fn = create_pipeline_generator(model_config, self.model)
+
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ start_time = time.time()
+
+ bt.logging.debug("Generating media")
+
+ if model_config.get("use_autocast", True):
+ pretrained_args = model_config.get("from_pretrained_args", {})
+ torch_dtype = pretrained_args.get("torch_dtype", torch.bfloat16)
+
+ with torch.autocast(self.device, torch_dtype, cache_enabled=False):
+ gen_output = generate_fn(truncated_prompt, **gen_args)
+ else:
+ gen_output = generate_fn(truncated_prompt, **gen_args)
+
+ gen_time = time.time() - start_time
+
+ hours = int(gen_time // 3600)
+ minutes = int((gen_time % 3600) // 60)
+ seconds = int(gen_time % 60)
+ bt.logging.info(
+ f"Finished generation in {hours:02d}:{minutes:02d}:{seconds:02d}"
+ )
+
+ modality = self.model_registry.get_modality(model_name)
+ media_type = self.model_registry.get_output_media_type(model_name)
+ output = {
+ modality: gen_output, # image or video
+ "modality": modality,
+ "media_type": media_type,
+ "prompt": truncated_prompt,
+ "model_name": model_name,
+ "time": time.time(),
+ "gen_duration": gen_time,
+ "mask_center": mask_center,
+ }
+ for k in ["num_inference_steps", "guidance_scale", "resolution"]:
+ output[k] = gen_args.get(k, "")
+
+ source_image = gen_args.get("image", None)
+ if source_image is not None:
+ output["source_image"] = source_image
+
+ mask_image = gen_args.get("mask_image", None)
+ if mask_image is not None:
+ output["mask_image"] = mask_image
+
+ del self.model
+ gc.collect()
+ torch.cuda.empty_cache()
+ return output
+
+    def _validate_model_names(self, model_names, tasks) -> List[str]:
+ if model_names is None:
+ if tasks is None:
+ model_names = self.model_registry.get_interleaved_model_names()
+ else:
+ tasks = [tasks] if not isinstance(tasks, (list, tuple)) else tasks
+                model_names = self.model_registry.get_interleaved_model_names(tasks)
+
+ elif isinstance(model_names, str):
+ model_names = [model_names]
+
+ invalid_models = [
+ name for name in model_names if name not in self.model_registry.model_names
+ ]
+ if invalid_models:
+ raise ValueError(
+ f"Invalid model names {invalid_models}. "
+ f"Options are {self.model_registry.model_names}"
+ )
+ return model_names
+
+ def _save_media_and_metadata(self, media_sample):
+ modality = media_sample["modality"]
+ media_type = media_sample["media_type"]
+ model_name = media_sample["model_name"]
+
+        output_dir = (
+            CacheConfig(
+                base_dir=self.output_dir, modality=modality, media_type=media_type
+            ).get_path()
+            / model_name.split("/")[1]
+        )
+
+        output_dir.mkdir(parents=True, exist_ok=True)
+        base_path = output_dir / str(media_sample["time"])
+ bt.logging.debug(f"[{modality}:{media_type}] Writing to cache")
+
+ metadata = {
+ k: v
+ for k, v in media_sample.items()
+ if k not in (modality, "source_image", "mask_image")
+ }
+ base_path.with_suffix(".json").write_text(json.dumps(metadata))
+
+ if modality == "image":
+ save_path = str(base_path.with_suffix(".png"))
+ media_sample[modality].images[0].save(save_path)
+ elif modality == "video":
+ save_path = str(base_path.with_suffix(".mp4"))
+ export_to_video(media_sample[modality].frames[0], save_path, fps=30)
+
+ bt.logging.info(f"Wrote to {save_path}")
+ return save_path
+
+ def clear_gpu(self):
+ if hasattr(self, "model") and self.model is not None:
+ bt.logging.trace("Deleting model")
+ if isinstance(self.model, dict):
+ for stage_name, stage_model in self.model.items():
+ del stage_model
+ else:
+ del self.model
+
+ if hasattr(self, "prompt_generator"):
+ bt.logging.trace("Deleting prompt generator")
+ self.prompt_generator.clear_gpu()
+
+ gc.collect()
+ if torch.cuda.is_available():
+ bt.logging.trace("Clearing CUDA cache")
+ torch.cuda.empty_cache()
+
+ def shutdown(self):
+ self.clear_gpu()
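+
+# Illustrative sketch (not part of the module): the pipeline expects image samples
+# shaped like {"image": np.ndarray, "path": str}; where they come from is up to the
+# caller (the validator presumably draws them from its media cache). rgb_array is a
+# placeholder for a loaded RGB image.
+#   pipeline = GenerationPipeline(output_dir="/tmp/sn34", device="cuda")
+#   samples = [{"image": rgb_array, "path": "/path/to/source.png"}]
+#   saved_paths = pipeline.generate(samples, tasks=["t2i"])
+#   pipeline.shutdown()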
diff --git a/bitmind/generation/model_registry.py b/bitmind/generation/model_registry.py
new file mode 100644
index 00000000..488ba755
--- /dev/null
+++ b/bitmind/generation/model_registry.py
@@ -0,0 +1,144 @@
+from typing import Optional, Dict, Union, Any, List
+import random
+
+from bitmind.types import ModelConfig, ModelTask
+
+
+class ModelRegistry:
+ """
+ Registry for managing generative models.
+ """
+
+ def __init__(self):
+ self.models: Dict[str, ModelConfig] = {}
+
+ def register(self, model_config: ModelConfig) -> None:
+ self.models[model_config.path] = model_config
+
+ def register_all(self, model_configs: List[ModelConfig]) -> None:
+ for config in model_configs:
+ self.register(config)
+
+ def get_model(self, path: str) -> Optional[ModelConfig]:
+ return self.models.get(path)
+
+ def get_all_models(self) -> Dict[str, ModelConfig]:
+ return self.models.copy()
+
+ def get_models_by_task(self, task: ModelTask) -> Dict[str, ModelConfig]:
+ return {
+ path: config for path, config in self.models.items() if config.task == task
+ }
+
+ def get_models_by_tag(self, tag: str) -> Dict[str, ModelConfig]:
+ return {
+ path: config for path, config in self.models.items() if tag in config.tags
+ }
+
+ def get_model_names_by_task(self, task: ModelTask) -> List[str]:
+ return list(self.get_models_by_task(task).keys())
+
+ @property
+ def t2i_models(self) -> Dict[str, ModelConfig]:
+ return self.get_models_by_task(ModelTask.TEXT_TO_IMAGE)
+
+ @property
+ def t2v_models(self) -> Dict[str, ModelConfig]:
+ return self.get_models_by_task(ModelTask.TEXT_TO_VIDEO)
+
+ @property
+ def i2i_models(self) -> Dict[str, ModelConfig]:
+ return self.get_models_by_task(ModelTask.IMAGE_TO_IMAGE)
+
+ @property
+    def i2v_models(self) -> Dict[str, ModelConfig]:
+ return self.get_models_by_task(ModelTask.IMAGE_TO_VIDEO)
+
+ @property
+ def t2i_model_names(self) -> List[str]:
+ return list(self.t2i_models.keys())
+
+ @property
+ def t2v_model_names(self) -> List[str]:
+ return list(self.t2v_models.keys())
+
+ @property
+ def i2i_model_names(self) -> List[str]:
+ return list(self.i2i_models.keys())
+
+ @property
+ def i2v_model_names(self) -> List[str]:
+ return list(self.i2v_models.keys())
+
+ @property
+ def model_names(self) -> List[str]:
+ return list(self.models.keys())
+
+ def select_random_model(self, task: Optional[Union[ModelTask, str]] = None) -> str:
+ if isinstance(task, str):
+ task = ModelTask(task.lower())
+
+ if task is None:
+ task = random.choice(list(ModelTask))
+
+ model_names = self.get_model_names_by_task(task)
+ if not model_names:
+ raise ValueError(f"No models available for task: {task}")
+
+ return random.choice(model_names)
+
+ def get_model_dict(self, model_name: str) -> Dict[str, Any]:
+ model = self.get_model(model_name)
+ if model is None:
+ raise ValueError(f"Model not found: {model_name}")
+
+ return model.to_dict()
+
+ def get_interleaved_model_names(self, tasks=None) -> List[str]:
+ from itertools import zip_longest
+
+ model_names = []
+ if tasks is None:
+ model_names = [
+ self.t2i_model_names,
+ self.t2v_model_names,
+ self.i2i_model_names,
+ self.i2v_model_names,
+ ]
+ else:
+ for task in tasks:
+ model_names.append(self.get_model_names_by_task(task))
+
+ shuffled_model_names = (
+ random.sample(names, len(names)) for names in model_names
+ )
+ return [
+ m
+ for quad in zip_longest(*shuffled_model_names)
+ for m in quad
+ if m is not None
+ ]
+
+ def get_modality(self, model_name: str) -> str:
+ model = self.get_model(model_name)
+ if model is None:
+ raise ValueError(f"Model not found: {model_name}")
+
+ return "video" if model.task == ModelTask.TEXT_TO_VIDEO else "image"
+
+ def get_task(self, model_name: str) -> str:
+ model = self.get_model(model_name)
+ if model is None:
+ raise ValueError(f"Model not found: {model_name}")
+
+ return model.task.value
+
+ def get_output_media_type(self, model_name: str) -> str:
+ model = self.get_model(model_name)
+ if model is None:
+ raise ValueError(f"Model not found: {model_name}")
+
+ return model.media_type.value
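+
+# Illustrative sketch (not part of the module); model configs come from
+# bitmind.generation.models, and the short task codes ('t2i', 't2v', ...) are assumed
+# to match the ModelTask enum values.
+#   registry = ModelRegistry()
+#   registry.register_all(get_text_to_image_models())
+#   name = registry.select_random_model(task="t2i")
+#   config = registry.get_model_dict(name)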
diff --git a/bitmind/generation/models.py b/bitmind/generation/models.py
new file mode 100644
index 00000000..0bc0b7ad
--- /dev/null
+++ b/bitmind/generation/models.py
@@ -0,0 +1,407 @@
+from typing import List
+
+import torch
+from diffusers import (
+ StableDiffusionXLPipeline,
+ StableDiffusionInpaintPipeline,
+ FluxPipeline,
+ StableDiffusionPipeline,
+ DEISMultistepScheduler,
+ EulerDiscreteScheduler,
+ IFPipeline,
+ IFSuperResolutionPipeline,
+ HunyuanVideoPipeline,
+ MochiPipeline,
+ CogVideoXPipeline,
+ AnimateDiffPipeline,
+ AutoPipelineForInpainting,
+ CogView4Pipeline,
+ CogVideoXImageToVideoPipeline,
+)
+
+from bitmind.generation.model_registry import ModelRegistry
+from bitmind.generation.util.model import (
+ load_hunyuanvideo_transformer,
+ load_annimatediff_motion_adapter,
+ JanusWrapper,
+)
+from bitmind.types import ModelConfig, ModelTask
+
+
+def get_text_to_image_models() -> List[ModelConfig]:
+ """
+ Get the list of text-to-image models.
+
+ Returns:
+ List of text-to-image model configurations
+ """
+ return [
+ ModelConfig(
+ path="stabilityai/stable-diffusion-xl-base-1.0",
+ task=ModelTask.TEXT_TO_IMAGE,
+ pipeline_cls=StableDiffusionXLPipeline,
+ pretrained_args={
+ "use_safetensors": True,
+ "torch_dtype": torch.float16,
+ "variant": "fp16",
+ },
+ use_autocast=False,
+ tags=["stable-diffusion", "xl"],
+ ),
+ ModelConfig(
+ path="SG161222/RealVisXL_V4.0",
+ task=ModelTask.TEXT_TO_IMAGE,
+ pipeline_cls=StableDiffusionXLPipeline,
+ pretrained_args={
+ "use_safetensors": True,
+ "torch_dtype": torch.float16,
+ "variant": "fp16",
+ },
+ tags=["stable-diffusion", "xl", "realistic"],
+ ),
+ ModelConfig(
+ path="Corcelio/mobius",
+ task=ModelTask.TEXT_TO_IMAGE,
+ pipeline_cls=StableDiffusionXLPipeline,
+ pretrained_args={"use_safetensors": True, "torch_dtype": torch.float16},
+ tags=["stable-diffusion", "xl"],
+ ),
+ ModelConfig(
+ path="black-forest-labs/FLUX.1-dev",
+ task=ModelTask.TEXT_TO_IMAGE,
+ pipeline_cls=FluxPipeline,
+ pretrained_args={
+ "use_safetensors": True,
+ "torch_dtype": torch.bfloat16,
+ },
+ generate_args={
+ "guidance_scale": 2,
+ "num_inference_steps": {"min": 50, "max": 125},
+ "generator": torch.Generator(
+ "cuda" if torch.cuda.is_available() else "cpu"
+ ),
+ "resolution": [512, 768],
+ },
+ enable_model_cpu_offload=False,
+ tags=["flux"],
+ ),
+ ModelConfig(
+ path="prompthero/openjourney-v4",
+ task=ModelTask.TEXT_TO_IMAGE,
+ pipeline_cls=StableDiffusionPipeline,
+ pretrained_args={
+ "use_safetensors": True,
+ "torch_dtype": torch.float16,
+ },
+ tags=["stable-diffusion", "midjourney-style"],
+ ),
+ ModelConfig(
+ path="cagliostrolab/animagine-xl-3.1",
+ task=ModelTask.TEXT_TO_IMAGE,
+ pipeline_cls=StableDiffusionXLPipeline,
+ pretrained_args={
+ "use_safetensors": True,
+ "torch_dtype": torch.float16,
+ },
+ tags=["stable-diffusion", "xl", "anime"],
+ ),
+ ModelConfig(
+ path="DeepFloyd/IF",
+ task=ModelTask.TEXT_TO_IMAGE,
+ pipeline_cls={"stage1": IFPipeline, "stage2": IFSuperResolutionPipeline},
+ pretrained_args={
+ "stage1": {
+ "base": "DeepFloyd/IF-I-XL-v1.0",
+ "torch_dtype": torch.float16,
+ "variant": "fp16",
+ "clean_caption": False,
+ "watermarker": None,
+ "requires_safety_checker": False,
+ },
+ "stage2": {
+ "base": "DeepFloyd/IF-II-L-v1.0",
+ "torch_dtype": torch.float16,
+ "variant": "fp16",
+ "text_encoder": None,
+ "watermarker": None,
+ "requires_safety_checker": False,
+ },
+ },
+ pipeline_stages=[
+ {
+ "name": "stage1",
+ "args": {
+ "output_type": "pt",
+ "num_images_per_prompt": 1,
+ "return_dict": True,
+ },
+ "output_attr": "images",
+ "output_transform": lambda x: x[0].unsqueeze(0),
+ "save_prompt_embeds": True,
+ },
+ {
+ "name": "stage2",
+ "input_key": "image",
+ "args": {"output_type": "pil", "num_images_per_prompt": 1},
+ "output_attr": "images",
+ "use_prompt_embeds": True,
+ },
+ ],
+ clear_memory_on_stage_end=True,
+ tags=["deepfloyd", "multi-stage"],
+ ),
+ ModelConfig(
+ path="deepseek-ai/Janus-Pro-7B",
+ task=ModelTask.TEXT_TO_IMAGE,
+ pipeline_cls=JanusWrapper,
+ pretrained_args={
+ "torch_dtype": torch.bfloat16,
+ "use_safetensors": True,
+ },
+ generate_args={
+ "temperature": 1.0,
+ "parallel_size": 4,
+ "cfg_weight": 5.0,
+ "image_token_num_per_image": 576,
+ "img_size": 384,
+ "patch_size": 16,
+ },
+ use_autocast=False,
+ enable_model_cpu_offload=False,
+ tags=["llm-based", "multimodal"],
+ ),
+ ModelConfig(
+ path="runwayml/stable-diffusion-v1-5-midjourney-v6",
+ task=ModelTask.TEXT_TO_IMAGE,
+ pipeline_cls=StableDiffusionPipeline,
+ pretrained_args={
+ "model_id": "runwayml/stable-diffusion-v1-5",
+ "torch_dtype": torch.float16,
+ "use_safetensors": True,
+ },
+ lora_model_id="Kvikontent/midjourney-v6",
+ lora_loading_args={"use_peft_backend": True},
+ use_autocast=False,
+ enable_model_cpu_offload=False,
+ tags=["stable-diffusion"],
+ ),
+ ModelConfig(
+ path="THUDM/CogView4-6B",
+ task=ModelTask.TEXT_TO_IMAGE,
+ pipeline_cls=CogView4Pipeline,
+ pretrained_args={
+ "torch_dtype": torch.bfloat16,
+ "use_safetensors": True,
+ },
+ generate_args={
+ "guidance_scale": 3.5,
+ "num_images_per_prompt": 1,
+ "num_inference_steps": 50,
+ "width": 512,
+ "height": 512,
+ },
+ use_autocast=False,
+ tags=[],
+ ),
+ ]
+
+
+def get_image_to_image_models() -> List[ModelConfig]:
+ """
+ Get the list of image-to-image models.
+
+ Returns:
+ List of image-to-image model configurations
+ """
+ return [
+ ModelConfig(
+ path="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
+ task=ModelTask.IMAGE_TO_IMAGE,
+ pipeline_cls=AutoPipelineForInpainting,
+ pretrained_args={
+ "use_safetensors": True,
+ "torch_dtype": torch.float16,
+ "variant": "fp16",
+ },
+ generate_args={
+ "guidance_scale": 7.5,
+ "num_inference_steps": 50,
+ "strength": 0.99,
+ "generator": torch.Generator(
+ "cuda" if torch.cuda.is_available() else "cpu"
+ )
+ },
+ tags=["stable-diffusion", "xl", "inpainting"],
+ ),
+ ModelConfig(
+ path="Lykon/dreamshaper-8-inpainting",
+ task=ModelTask.IMAGE_TO_IMAGE,
+ pipeline_cls=AutoPipelineForInpainting,
+ pretrained_args={"torch_dtype": torch.float16, "variant": "fp16"},
+ generate_args={
+ "num_inference_steps": {"min": 40, "max": 60},
+ },
+ scheduler={"cls": DEISMultistepScheduler},
+ tags=["stable-diffusion", "inpainting", "dreamshaper"],
+ ),
+ ]
+
+
+def get_text_to_video_models() -> List[ModelConfig]:
+ """
+ Get the list of text-to-video models.
+
+ Returns:
+ List of text-to-video model configurations
+ """
+ return [
+ ModelConfig(
+ path="tencent/HunyuanVideo",
+ task=ModelTask.TEXT_TO_VIDEO,
+ pipeline_cls=HunyuanVideoPipeline,
+ pretrained_args={
+ "model_id": "tencent/HunyuanVideo",
+ "transformer": ( # custom functions supplied as tuple of (fn, args)
+ load_hunyuanvideo_transformer,
+ {
+ "model_id": "tencent/HunyuanVideo",
+ "subfolder": "transformer",
+ "torch_dtype": torch.bfloat16,
+ "revision": "refs/pr/18",
+ },
+ ),
+ "revision": "refs/pr/18",
+ "torch_dtype": torch.bfloat16,
+ },
+ generate_args={
+ "num_frames": {"min": 61, "max": 129},
+ "resolution": {
+ "options": [
+ [720, 1280],
+ [1280, 720],
+ [1104, 832],
+ [832, 1104],
+ [960, 960],
+ [544, 960],
+ [960, 544],
+ [624, 832],
+ [832, 624],
+ [720, 720],
+ ]
+ },
+ "num_inference_steps": {"min": 30, "max": 50},
+ },
+ save_args={"fps": 30},
+ use_autocast=False,
+ vae_enable_tiling=True,
+ tags=["high-quality", "high-resolution"],
+ ),
+ ModelConfig(
+ path="genmo/mochi-1-preview",
+ task=ModelTask.TEXT_TO_VIDEO,
+ pipeline_cls=MochiPipeline,
+ pretrained_args={"variant": "bf16", "torch_dtype": torch.bfloat16},
+ generate_args={
+ "num_frames": 84,
+ "num_inference_steps": {"min": 30, "max": 65},
+ "resolution": [480, 848],
+ },
+ save_args={"fps": 30},
+ vae_enable_tiling=True,
+ tags=["mochi"],
+ ),
+ ModelConfig(
+ path="THUDM/CogVideoX-5b",
+ task=ModelTask.TEXT_TO_VIDEO,
+ pipeline_cls=CogVideoXPipeline,
+ pretrained_args={"use_safetensors": True, "torch_dtype": torch.bfloat16},
+ generate_args={
+ "guidance_scale": 2,
+ "num_videos_per_prompt": 1,
+ "num_inference_steps": {"min": 50, "max": 125},
+ "num_frames": 48,
+ },
+ save_args={"fps": 8},
+ enable_model_cpu_offload=True,
+ vae_enable_slicing=True,
+ vae_enable_tiling=True,
+ tags=["cogvideo"],
+ ),
+ ModelConfig(
+ path="ByteDance/AnimateDiff-Lightning",
+ task=ModelTask.TEXT_TO_VIDEO,
+ pipeline_cls=AnimateDiffPipeline,
+ pretrained_args={
+ "model_id": "emilianJR/epiCRealism",
+ "torch_dtype": torch.bfloat16,
+ "motion_adapter": (load_annimatediff_motion_adapter, {"step": 4}),
+ },
+ generate_args={
+ "guidance_scale": 2,
+ "num_inference_steps": {"min": 50, "max": 125},
+ "resolution": {
+ "options": [
+ [512, 512],
+ [512, 768],
+ [512, 1024],
+ [768, 512],
+ [768, 768],
+ [768, 1024],
+ [1024, 512],
+ [1024, 768],
+ [1024, 1024],
+ ]
+ },
+ },
+ save_args={"fps": 15},
+ scheduler={
+ "cls": EulerDiscreteScheduler,
+ "from_config_args": {
+ "timestep_spacing": "trailing",
+ "beta_schedule": "linear",
+ },
+ },
+ tags=["animate-diff", "motion-adapter"],
+ ),
+ ]
+
+
+def get_image_to_video_models() -> List[ModelConfig]:
+    """
+    Get the list of image-to-video models.
+
+    Returns:
+        List of image-to-video model configurations
+    """
+ return [
+ ModelConfig(
+ path="THUDM/CogVideoX1.5-5B-I2V",
+ task=ModelTask.IMAGE_TO_VIDEO,
+ pipeline_cls=CogVideoXImageToVideoPipeline,
+ pretrained_args={"use_safetensors": True, "torch_dtype": torch.bfloat16},
+ generate_args={
+ "guidance_scale": 2,
+ "num_videos_per_prompt": 1,
+ "num_inference_steps": {"min": 50, "max": 125},
+ "num_frames": 49,
+ "height": 768,
+ "width": 768,
+ },
+ save_args={"fps": 8},
+ enable_model_cpu_offload=True,
+ vae_enable_slicing=True,
+ vae_enable_tiling=True,
+ )
+ ]
+
+
+def initialize_model_registry() -> ModelRegistry:
+ """
+ Initialize and populate the model registry.
+
+ Returns:
+ Fully populated ModelRegistry instance
+ """
+ registry = ModelRegistry()
+
+ registry.register_all(get_text_to_image_models())
+ registry.register_all(get_image_to_image_models())
+ registry.register_all(get_text_to_video_models())
+ registry.register_all(get_image_to_video_models())
+
+ return registry
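+
+# Illustrative sketch (not part of the module), assuming ModelTask values map to the
+# short task codes used elsewhere ('t2i', 't2v', 'i2i', 'i2v'):
+#   registry = initialize_model_registry()
+#   registry.t2i_model_names        # e.g. ['stabilityai/stable-diffusion-xl-base-1.0', ...]
+#   registry.get_task("genmo/mochi-1-preview")  # 't2v'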
diff --git a/bitmind/synthetic_data_generation/prompt_generator.py b/bitmind/generation/prompt_generator.py
similarity index 64%
rename from bitmind/synthetic_data_generation/prompt_generator.py
rename to bitmind/generation/prompt_generator.py
index e418ef08..63a56576 100644
--- a/bitmind/synthetic_data_generation/prompt_generator.py
+++ b/bitmind/generation/prompt_generator.py
@@ -1,7 +1,6 @@
+import re
import gc
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
-
+import bittensor as bt
import torch
from PIL import Image
from transformers import (
@@ -10,14 +9,7 @@
Blip2ForConditionalGeneration,
Blip2Processor,
pipeline,
- logging as transformers_logging,
)
-from transformers.utils.logging import disable_progress_bar
-
-import bittensor as bt
-from bitmind.validator.config import HUGGINGFACE_CACHE_DIR
-
-disable_progress_bar()
class PromptGenerator:
@@ -33,7 +25,7 @@ def __init__(
self,
vlm_name: str,
llm_name: str,
- device: str = 'cuda',
+ device: str = "cuda",
) -> None:
"""
Initialize the ImageAnnotationGenerator with specific models and device settings.
@@ -50,81 +42,73 @@ def __init__(
self.llm_name = llm_name
self.vlm_processor = None
self.vlm = None
- self.llm_pipeline = None
+ self.llm = None
self.device = device
- def are_models_loaded(self) -> bool:
- return (self.vlm is not None) and (self.llm_pipeline is not None)
-
def load_vlm(self) -> None:
- bt.logging.info(f"Loading caption generation model {self.vlm_name}")
+ """
+ Load the vision-language model for image annotation.
+ """
+ bt.logging.debug(f"Loading caption generation model {self.vlm_name}")
self.vlm_processor = Blip2Processor.from_pretrained(
- self.vlm_name,
- cache_dir=HUGGINGFACE_CACHE_DIR
+ self.vlm_name, torch_dtype=torch.float32
)
self.vlm = Blip2ForConditionalGeneration.from_pretrained(
- self.vlm_name,
- torch_dtype=torch.float16,
- cache_dir=HUGGINGFACE_CACHE_DIR
+ self.vlm_name, torch_dtype=torch.float32
)
self.vlm.to(self.device)
bt.logging.info(f"Loaded image annotation model {self.vlm_name}")
-
+
def load_llm(self) -> None:
- bt.logging.info(f"Loading caption moderation model {self.llm_name}")
+ """
+ Load the language model for text moderation.
+ """
+ bt.logging.debug(f"Loading caption moderation model {self.llm_name}")
+ m = re.match(r"cuda:(\d+)", self.device)
+ gpu_id = int(m.group(1)) if m else 0
llm = AutoModelForCausalLM.from_pretrained(
- self.llm_name,
+ self.llm_name,
torch_dtype=torch.bfloat16,
- cache_dir=HUGGINGFACE_CACHE_DIR
- )
- tokenizer = AutoTokenizer.from_pretrained(
- self.llm_name,
- cache_dir=HUGGINGFACE_CACHE_DIR
- )
- llm = llm.to(self.device)
- self.llm_pipeline = pipeline(
- "text-generation",
- model=llm,
- tokenizer=tokenizer
+ device_map={"": gpu_id}
)
+ tokenizer = AutoTokenizer.from_pretrained(self.llm_name)
+ self.llm = pipeline("text-generation", model=llm, tokenizer=tokenizer)
bt.logging.info(f"Loaded caption moderation model {self.llm_name}")
-
+
def load_models(self) -> None:
"""
Load the necessary models for image annotation and text moderation onto
the specified device.
"""
- if self.are_models_loaded():
- bt.logging.warning(f"Models already loaded")
- return
- self.load_vlm()
- self.load_llm()
+ if self.vlm is None:
+ self.load_vlm()
+ else:
+ bt.logging.warning(f"vlm already loaded")
+
+ if self.llm is None:
+ self.load_llm()
+ else:
+ bt.logging.warning(f"llm already loaded")
def clear_gpu(self) -> None:
"""
Clear GPU memory by moving models back to CPU and deleting them,
followed by collecting garbage.
"""
- bt.logging.info("Clearing GPU memory after prompt generation")
+ bt.logging.debug("Clearing GPU memory after prompt generation")
if self.vlm:
- self.vlm.to('cpu')
del self.vlm
self.vlm = None
- if self.llm_pipeline:
- self.llm_pipeline.model.to('cpu')
- del self.llm_pipeline
- self.llm_pipeline = None
+ if self.llm:
+ del self.llm
+ self.llm = None
gc.collect()
torch.cuda.empty_cache()
def generate(
- self,
- image: Image.Image,
- task: Optional[str] = None,
- max_new_tokens: int = 20,
- verbose: bool = False
+ self, image: Image.Image, downstream_task: str = None, max_new_tokens: int = 20
) -> str:
"""
Generate a string description for a given image using prompt-based
@@ -136,64 +120,53 @@ def generate(
motion descriptions will be added.
max_new_tokens: The maximum number of tokens to generate for each
prompt.
- verbose: If True, additional logging information is printed.
Returns:
A generated description of the image.
"""
- if not verbose:
- transformers_logging.set_verbosity_error()
+ if self.vlm is None or self.vlm_processor is None:
+ self.load_vlm()
description = ""
prompts = [
"An image of",
"The setting is",
"The background is",
- "The image type/style is"
+ "The image type/style is",
]
for i, prompt in enumerate(prompts):
- description += prompt + ' '
+ description += prompt + " "
inputs = self.vlm_processor(
- image,
- text=description,
- return_tensors="pt"
- ).to(self.device, torch.float16)
-
- generated_ids = self.vlm.generate(
- **inputs,
- max_new_tokens=max_new_tokens
- )
+ image, text=description, return_tensors="pt"
+ ).to(self.device, torch.float32)
+
+ generated_ids = self.vlm.generate(**inputs, max_new_tokens=max_new_tokens)
answer = self.vlm_processor.batch_decode(
- generated_ids,
- skip_special_tokens=True
+ generated_ids, skip_special_tokens=True
)[0].strip()
- if verbose:
- bt.logging.info(f"{i}. Prompt: {prompt}")
- bt.logging.info(f"{i}. Answer: {answer}")
+ bt.logging.trace(f"{i}. Prompt: {prompt}")
+ bt.logging.trace(f"{i}. Answer: {answer}")
if answer:
answer = answer.rstrip(" ,;!?")
- if not answer.endswith('.'):
- answer += '.'
- description += answer + ' '
+ if not answer.endswith("."):
+ answer += "."
+ description += answer + " "
else:
- description = description[:-len(prompt) - 1]
-
- if not verbose:
- transformers_logging.set_verbosity_info()
+ description = description[: -len(prompt) - 1]
if description.startswith(prompts[0]):
- description = description[len(prompts[0]):]
+ description = description[len(prompts[0]) :]
description = description.strip()
- if not description.endswith('.'):
- description += '.'
+ if not description.endswith("."):
+ description += "."
moderated_description = self.moderate(description)
-
- if task in ['t2v', 'i2v']:
+
+ if downstream_task in ["t2v", "i2v"]:
return self.enhance(moderated_description)
return moderated_description
@@ -211,6 +184,9 @@ def moderate(self, description: str, max_new_tokens: int = 80) -> str:
The moderated description text, or the original description if
moderation fails.
"""
+ if self.llm is None:
+ self.load_llm()
+
messages = [
{
"role": "system",
@@ -219,21 +195,18 @@ def moderate(self, description: str, max_new_tokens: int = 80) -> str:
"eliminate redundancy, and remove all specific references to "
"individuals by name. You do not respond with anything other "
"than the revised description.[/INST]"
- )
+ ),
},
- {
- "role": "user",
- "content": description
- }
+ {"role": "user", "content": description},
]
try:
- moderated_text = self.llm_pipeline(
+ moderated_text = self.llm(
messages,
max_new_tokens=max_new_tokens,
- pad_token_id=self.llm_pipeline.tokenizer.eos_token_id,
- return_full_text=False
+ pad_token_id=self.llm.tokenizer.eos_token_id,
+ return_full_text=False,
)
- return moderated_text[0]['generated_text']
+ return moderated_text[0]["generated_text"]
except Exception as e:
bt.logging.error(f"An error occurred during moderation: {e}", exc_info=True)
@@ -243,14 +216,18 @@ def enhance(self, description: str, max_new_tokens: int = 80) -> str:
"""
Enhance a static image description to make it suitable for video generation
by adding dynamic elements and motion.
-
+
Args:
description: The static image description to enhance.
max_new_tokens: Maximum number of new tokens to generate in the enhanced text.
-
+
Returns:
- An enhanced description suitable for video generation.
+ An enhanced description suitable for video generation, or the original
+ description if enhancement fails.
"""
+ if self.llm is None:
+ self.load_llm()
+
messages = [
{
"role": "system",
@@ -266,31 +243,32 @@ def enhance(self, description: str, max_new_tokens: int = 80) -> str:
"3. Add ONE subtle camera motion that complements the scene\n"
"4. Keep the description concise and natural\n"
"Only respond with the enhanced description.[/INST]"
- )
+ ),
},
- {
- "role": "user",
- "content": description
- }
+ {"role": "user", "content": description},
]
try:
- enhanced_text = self.llm_pipeline(
+ enhanced_text = self.llm(
messages,
max_new_tokens=max_new_tokens,
- pad_token_id=self.llm_pipeline.tokenizer.eos_token_id,
- return_full_text=False
+ pad_token_id=self.llm.tokenizer.eos_token_id,
+ return_full_text=False,
)
- return enhanced_text[0]['generated_text']
+ return enhanced_text[0]["generated_text"]
except Exception as e:
bt.logging.error(f"An error occurred during motion enhancement: {e}")
return description
- def sanitize_prompt(self, prompt: str, max_new_tokens: int = 80) -> str:
+ def sanitize(self, prompt: str, max_new_tokens: int = 80) -> str:
"""
Use the LLM to make the prompt more SFW (less NSFW).
"""
+
+ if self.llm is None:
+ self.load_llm()
+
messages = [
{
"role": "system",
@@ -299,21 +277,18 @@ def sanitize_prompt(self, prompt: str, max_new_tokens: int = 80) -> str:
"Rephrase the following prompt to remove or neutralize any NSFW, sexual, or explicit content. "
"Keep the prompt as close as possible to the original intent, but ensure it is SFW. "
"Only respond with the sanitized prompt.[/INST]"
- )
+ ),
},
- {
- "role": "user",
- "content": prompt
- }
+ {"role": "user", "content": prompt},
]
try:
- sanitized = self.llm_pipeline(
+ sanitized = self.llm(
messages,
max_new_tokens=max_new_tokens,
- pad_token_id=self.llm_pipeline.tokenizer.eos_token_id,
- return_full_text=False
+ pad_token_id=self.llm.tokenizer.eos_token_id,
+ return_full_text=False,
)
- return sanitized[0]['generated_text']
+ return sanitized[0]["generated_text"]
except Exception as e:
bt.logging.error(f"An error occurred during prompt sanitization: {e}")
- return prompt
\ No newline at end of file
+ return prompt
diff --git a/base_miner/NPR/__init__.py b/bitmind/generation/util/__init__.py
similarity index 100%
rename from base_miner/NPR/__init__.py
rename to bitmind/generation/util/__init__.py
diff --git a/bitmind/synthetic_data_generation/image_utils.py b/bitmind/generation/util/image.py
similarity index 54%
rename from bitmind/synthetic_data_generation/image_utils.py
rename to bitmind/generation/util/image.py
index f3ad50be..04042f3b 100644
--- a/bitmind/synthetic_data_generation/image_utils.py
+++ b/bitmind/generation/util/image.py
@@ -2,12 +2,12 @@
import PIL
import os
from PIL import Image, ImageDraw
-from typing import Tuple
+from typing import Any, Dict, Tuple
-from bitmind.validator.config import TARGET_IMAGE_SIZE
-
-def resize_image(image: PIL.Image.Image, max_width: int, max_height: int) -> PIL.Image.Image:
+def resize_image(
+ image: PIL.Image.Image, max_width: int, max_height: int
+) -> PIL.Image.Image:
"""Resize the image to fit within specified dimensions while maintaining aspect ratio."""
original_width, original_height = image.size
@@ -25,7 +25,7 @@ def resize_image(image: PIL.Image.Image, max_width: int, max_height: int) -> PIL
return resized_image
-def resize_images_in_directory(directory, target_width=TARGET_IMAGE_SIZE[0], target_height=TARGET_IMAGE_SIZE[1]):
+def resize_images_in_directory(directory, target_width, target_height):
"""
Resize all images in the specified directory to the target width and height.
@@ -36,27 +36,35 @@ def resize_images_in_directory(directory, target_width=TARGET_IMAGE_SIZE[0], tar
"""
# List all files in the directory
for filename in os.listdir(directory):
- if filename.endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')): # Check for image file extensions
+ if filename.endswith(
+ (".png", ".jpg", ".jpeg", ".bmp", ".gif")
+ ): # Check for image file extensions
filepath = os.path.join(directory, filename)
with PIL.Image.open(filepath) as img:
# Resize the image and save back to the file location
- resized_img = resize_image(img, max_width=target_width, max_height=target_height)
+ resized_img = resize_image(
+ img, max_width=target_width, max_height=target_height
+ )
resized_img.save(filepath)
-
-def save_images_to_disk(image_dataset, start_index, num_images, save_directory, resize=True):
+
+def save_images_to_disk(
+ image_dataset, start_index, num_images, save_directory, resize=True
+):
if not os.path.exists(save_directory):
os.makedirs(save_directory)
for i in range(start_index, start_index + num_images):
try:
image_data = image_dataset[i] # Retrieve image using the __getitem__ method
- image = image_data['image'] # Extract the image
- image_id = image_data['id'] # Extract the image ID
- file_path = os.path.join(save_directory, f"{image_id}.jpg") # Construct file path
- if resize:
- image = resize_image(image, TARGET_IMAGE_SIZE[0], TARGET_IMAGE_SIZE[1])
- image.save(file_path, 'JPEG') # Save the image
+ image = image_data["image"] # Extract the image
+ image_id = image_data["id"] # Extract the image ID
+ file_path = os.path.join(
+ save_directory, f"{image_id}.jpg"
+ ) # Construct file path
+ # if resize:
+ # image = resize_image(image, TARGET_IMAGE_SIZE[0], TARGET_IMAGE_SIZE[1])
+ image.save(file_path, "JPEG") # Save the image
print(f"Saved: {file_path}")
except Exception as e:
print(f"Failed to save image {i}: {e}")
@@ -67,50 +75,53 @@ def create_random_mask(size: Tuple[int, int]) -> Image.Image:
Create a random mask for i2i transformation.
"""
w, h = size
- mask = Image.new('RGB', size, 'black')
+ mask = Image.new("RGB", size, "black")
if np.random.rand() < 0.5:
# Rectangular mask with smoother edges
- width = np.random.randint(w//4, w//2)
- height = np.random.randint(h//4, h//2)
+ width = np.random.randint(w // 4, w // 2)
+ height = np.random.randint(h // 4, h // 2)
# Center the rectangle with some random offset
- x = (w - width) // 2 + np.random.randint(-width//4, width//4)
- y = (h - height) // 2 + np.random.randint(-height//4, height//4)
+ x = (w - width) // 2 + np.random.randint(-width // 4, width // 4)
+ y = (h - height) // 2 + np.random.randint(-height // 4, height // 4)
# Create mask with PIL draw for smoother edges
draw = ImageDraw.Draw(mask)
draw.rounded_rectangle(
[x, y, x + width, y + height],
radius=min(width, height) // 10, # Smooth corners
- fill='white'
+ fill="white",
)
else:
# Circular mask with feathered edges
draw = ImageDraw.Draw(mask)
- x = w//2
- y = h//2
+ x = w // 2
+ y = h // 2
# Make radius proportional to image size
radius = min(w, h) // 4
# Add small random offset to center
- x += np.random.randint(-radius//4, radius//4)
- y += np.random.randint(-radius//4, radius//4)
+ x += np.random.randint(-radius // 4, radius // 4)
+ y += np.random.randint(-radius // 4, radius // 4)
# Draw multiple circles with decreasing opacity for feathered edge
- for r in range(radius, radius-10, -1):
- opacity = int(255 * (r - (radius-10)) / 10)
- draw.ellipse(
- [x-r, y-r, x+r, y+r],
- fill=(255, 255, 255, opacity)
- )
+ for r in range(radius, radius - 10, -1):
+ opacity = int(255 * (r - (radius - 10)) / 10)
+ draw.ellipse([x - r, y - r, x + r, y + r], fill=(255, 255, 255, opacity))
return mask, (x, y)
-def is_black_image(img: Image.Image, threshold: int = 10) -> bool:
+
+def is_black_output(
+    modality: str, output: Dict[str, Any], threshold: int = 10
+) -> bool:
+    """
-    Returns True if the image is (almost) completely black.
+    Returns True if the generated image or video frames are (almost) completely black.
+
+    `output` is the generation result dict keyed by modality; its value is a pipeline
+    output exposing `.images` (image) or `.frames` (video).
+    """
- arr = np.array(img)
- return np.mean(arr) < threshold
\ No newline at end of file
+ if modality == "image":
+ arr = np.array(output[modality].images[0])
+ return np.mean(arr) < threshold
+ elif modality == "video":
+ return np.all([np.mean(np.array(arr)) < threshold for arr in output[modality].frames[0]])
diff --git a/bitmind/validator/model_utils.py b/bitmind/generation/util/model.py
similarity index 68%
rename from bitmind/validator/model_utils.py
rename to bitmind/generation/util/model.py
index 5159d83d..13acae0a 100644
--- a/bitmind/validator/model_utils.py
+++ b/bitmind/generation/util/model.py
@@ -1,29 +1,31 @@
-import torch
+import PIL.Image
import numpy as np
-from diffusers import MotionAdapter, HunyuanVideoTransformer3DModel, DiffusionPipeline
+import torch
+import bittensor as bt
+from diffusers import (
+ DiffusionPipeline,
+ HunyuanVideoTransformer3DModel,
+ MotionAdapter,
+)
from huggingface_hub import hf_hub_download
+from janus.models import VLChatProcessor
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM
-from janus.models import VLChatProcessor
-import PIL.Image
-from typing import Dict, Any, Any, Optional
-import bittensor as bt
+from typing import Any, Dict, Optional
def load_hunyuanvideo_transformer(
- model_id: str = "tencent/HunyuanVideo",
- subfolder: str = "transformer",
- torch_dtype: torch.dtype = torch.bfloat16,
- revision: str = 'refs/pr/18'
+ model_id: str = "tencent/HunyuanVideo",
+ subfolder: str = "transformer",
+ torch_dtype: torch.dtype = torch.bfloat16,
+ revision: str = "refs/pr/18",
):
return HunyuanVideoTransformer3DModel.from_pretrained(
model_id, subfolder=subfolder, torch_dtype=torch_dtype, revision=revision
)
-def load_annimatediff_motion_adapter(
- step: int = 4
-) -> MotionAdapter:
+def load_annimatediff_motion_adapter(step: int = 4) -> MotionAdapter:
"""
Load a motion adapter model for AnimateDiff.
@@ -45,12 +47,7 @@ def load_annimatediff_motion_adapter(
repo = "ByteDance/AnimateDiff-Lightning"
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
- adapter.load_state_dict(
- load_file(
- hf_hub_download(repo, ckpt),
- device=device
- )
- )
+ adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
return adapter
@@ -61,9 +58,7 @@ def __init__(self, model, processor):
self.processor = processor
self.tokenizer = self.processor.tokenizer
self.register_modules(
- model=model,
- processor=processor,
- tokenizer=self.processor.tokenizer
+ model=model, processor=processor, tokenizer=self.processor.tokenizer
)
@torch.inference_mode()
@@ -76,7 +71,7 @@ def __call__(
image_token_num_per_image: int = 576,
img_size: int = 384,
patch_size: int = 16,
- **kwargs
+ **kwargs,
):
conversation = [
{
@@ -96,41 +91,47 @@ def __call__(
input_ids = self.processor.tokenizer.encode(prompt)
input_ids = torch.LongTensor(input_ids).to(self.device)
- tokens = torch.zeros((parallel_size*2, len(input_ids)), dtype=torch.int).to(self.device)
- for i in range(parallel_size*2):
+ tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(
+ self.device
+ )
+ for i in range(parallel_size * 2):
tokens[i, :] = input_ids
if i % 2 != 0:
tokens[i, 1:-1] = self.processor.pad_id
inputs_embeds = self.model.language_model.get_input_embeddings()(tokens)
- generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(self.device)
+ generated_tokens = torch.zeros(
+ (parallel_size, image_token_num_per_image), dtype=torch.int
+ ).to(self.device)
outputs = None
for i in range(image_token_num_per_image):
outputs = self.model.language_model.model(
- inputs_embeds=inputs_embeds,
- use_cache=True,
- past_key_values=outputs.past_key_values if i != 0 else None
+ inputs_embeds=inputs_embeds,
+ use_cache=True,
+ past_key_values=outputs.past_key_values if i != 0 else None,
)
hidden_states = outputs.last_hidden_state
-
+
logits = self.model.gen_head(hidden_states[:, -1, :])
logit_cond = logits[0::2, :]
logit_uncond = logits[1::2, :]
-
- logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
+
+ logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
probs = torch.softmax(logits / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
generated_tokens[:, i] = next_token.squeeze(dim=-1)
- next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+ next_token = torch.cat(
+ [next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1
+ ).view(-1)
img_embeds = self.model.prepare_gen_img_embeds(next_token)
inputs_embeds = img_embeds.unsqueeze(dim=1)
dec = self.model.gen_vision_model.decode_code(
- generated_tokens.to(dtype=torch.int),
- shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size]
+ generated_tokens.to(dtype=torch.int),
+ shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size],
)
dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
dec = np.clip((dec + 1) / 2 * 255, 0, 255)
@@ -138,7 +139,7 @@ def __call__(
images = []
for i in range(parallel_size):
images.append(PIL.Image.fromarray(dec[i].astype(np.uint8)))
-
+
# Return object with images attribute
class Output:
def __init__(self, images):
@@ -150,7 +151,7 @@ def __init__(self, images):
def from_pretrained(cls, model_path, **kwargs):
model, processor = load_janus_model(model_path, **kwargs)
return cls(model=model, processor=processor)
-
+
def to(self, device):
self.model = self.model.to(device)
return self
@@ -158,85 +159,87 @@ def to(self, device):
def load_janus_model(model_path: str, **kwargs):
processor = VLChatProcessor.from_pretrained(model_path)
-
+
# Filter kwargs to only include what Janus expects
janus_kwargs = {
- 'trust_remote_code': True,
- 'torch_dtype': kwargs.get('torch_dtype', torch.bfloat16)
+ "trust_remote_code": True,
+ "torch_dtype": kwargs.get("torch_dtype", torch.bfloat16),
}
-
+
# Let device placement be handled by diffusers like other models
- model = AutoModelForCausalLM.from_pretrained(
- model_path,
- **janus_kwargs
- ).eval()
-
+ model = AutoModelForCausalLM.from_pretrained(model_path, **janus_kwargs).eval()
+
return model, processor
def create_pipeline_generator(model_config: Dict[str, Any], model: Any) -> callable:
"""
Creates a generator function based on pipeline configuration.
-
+
Args:
model_config: Model configuration dictionary
model: Loaded model instance(s)
-
+
Returns:
Callable that handles the generation process for the model
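+
+    Example (illustrative) multi-stage config, using the keys this function reads:
+        {"pipeline_stages": [
+            {"name": "stage1", "save_prompt_embeds": True},
+            {"name": "stage2", "input_key": "image", "use_prompt_embeds": True},
+        ],
+         "clear_memory_on_stage_end": True}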
"""
- if isinstance(model_config.get('pipeline_stages'), list):
+ if isinstance(model_config.get("pipeline_stages"), list):
+
def generate(prompt: str, **kwargs):
output = None
prompt_embeds = None
negative_embeds = None
-
- for stage in model_config['pipeline_stages']:
+
+ for stage in model_config["pipeline_stages"]:
stage_args = {**kwargs} # Copy base args
-
+
# Add stage-specific args
- if stage.get('input_key') and output is not None:
- stage_args[stage['input_key']] = output
-
+ if stage.get("input_key") and output is not None:
+ stage_args[stage["input_key"]] = output
+
# Add any stage-specific generation args
- if stage.get('args'):
- stage_args.update(stage['args'])
-
+ if stage.get("args"):
+ stage_args.update(stage["args"])
+
# Handle prompt embeddings
- if stage.get('use_prompt_embeds') and prompt_embeds is not None:
- stage_args['prompt_embeds'] = prompt_embeds
- stage_args['negative_prompt_embeds'] = negative_embeds
- stage_args.pop('prompt', None)
- elif stage.get('save_prompt_embeds'):
+ if stage.get("use_prompt_embeds") and prompt_embeds is not None:
+ stage_args["prompt_embeds"] = prompt_embeds
+ stage_args["negative_prompt_embeds"] = negative_embeds
+ stage_args.pop("prompt", None)
+ elif stage.get("save_prompt_embeds"):
# Get embeddings directly from encode_prompt
- prompt_embeds, negative_embeds = model[stage['name']].encode_prompt(
+ prompt_embeds, negative_embeds = model[stage["name"]].encode_prompt(
prompt=prompt,
- device=model[stage['name']].device,
- num_images_per_prompt=stage_args.get('num_images_per_prompt', 1),
+ device=model[stage["name"]].device,
+ num_images_per_prompt=stage_args.get(
+ "num_images_per_prompt", 1
+ ),
)
- stage_args['prompt_embeds'] = prompt_embeds
- stage_args['negative_prompt_embeds'] = negative_embeds
- stage_args.pop('prompt', None)
+ stage_args["prompt_embeds"] = prompt_embeds
+ stage_args["negative_prompt_embeds"] = negative_embeds
+ stage_args.pop("prompt", None)
else:
- stage_args['prompt'] = prompt
-
+ stage_args["prompt"] = prompt
+
# Run stage
- result = model[stage['name']](**stage_args)
-
+ result = model[stage["name"]](**stage_args)
+
# Extract output based on stage config
- output = getattr(result, stage.get('output_attr', 'images'))
-
+ output = getattr(result, stage.get("output_attr", "images"))
+
# Clear memory if configured
- if model_config.get('clear_memory_on_stage_end'):
+ if model_config.get("clear_memory_on_stage_end"):
import gc
import torch
+
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
-
+
return result
+
return generate
-
+
# Default single-stage pipeline
return lambda prompt, **kwargs: model(prompt=prompt, **kwargs)
@@ -253,7 +256,7 @@ def enable_model_optimizations(
) -> None:
"""
Enables various model optimizations for better memory usage and performance.
-
+
Args:
model: The model to optimize
device: Device to move model to ('cuda', 'cpu', etc)
@@ -265,41 +268,45 @@ def enable_model_optimizations(
stage_name: Optional name of pipeline stage for logging
"""
model_name = f"{stage_name} " if stage_name else ""
-
+
if disable_progress_bar:
- bt.logging.info(f"Disabling progress bar for {model_name}model")
+ bt.logging.debug(f"Disabling progress bar for {model_name}model")
model.set_progress_bar_config(disable=True)
# Handle CPU offloading
if enable_cpu_offload:
- bt.logging.info(f"Enabling CPU offload for {model_name}model")
+ bt.logging.debug(f"Enabling CPU offload for {model_name}model")
model.enable_model_cpu_offload(device=device)
elif enable_sequential_cpu_offload:
- bt.logging.info(f"Enabling sequential CPU offload for {model_name}model")
+ bt.logging.debug(f"Enabling sequential CPU offload for {model_name}model")
model.enable_sequential_cpu_offload()
else:
# Only move to device if not using CPU offload
- bt.logging.info(f"Moving {model_name}model to {device}")
+ bt.logging.debug(f"Moving {model_name}model to {device}")
model.to(device)
# Handle VAE optimizations if not using CPU offload
if not enable_cpu_offload:
if enable_vae_slicing:
- bt.logging.info(f"Enabling VAE slicing for {model_name}model")
+ bt.logging.debug(f"Enabling VAE slicing for {model_name}model")
try:
model.vae.enable_slicing()
except Exception:
try:
model.enable_vae_slicing()
except Exception as e:
- bt.logging.warning(f"Failed to enable VAE slicing for {model_name}model: {e}")
+ bt.logging.warning(
+ f"Failed to enable VAE slicing for {model_name}model: {e}"
+ )
if enable_vae_tiling:
- bt.logging.info(f"Enabling VAE tiling for {model_name}model")
+ bt.logging.debug(f"Enabling VAE tiling for {model_name}model")
try:
model.vae.enable_tiling()
except Exception:
try:
model.enable_vae_tiling()
except Exception as e:
- bt.logging.warning(f"Failed to enable VAE tiling for {model_name}model: {e}")
+ bt.logging.warning(
+ f"Failed to enable VAE tiling for {model_name}model: {e}"
+ )
diff --git a/bitmind/synthetic_data_generation/prompt_utils.py b/bitmind/generation/util/prompt.py
similarity index 67%
rename from bitmind/synthetic_data_generation/prompt_utils.py
rename to bitmind/generation/util/prompt.py
index f2407f88..e412b35a 100644
--- a/bitmind/synthetic_data_generation/prompt_utils.py
+++ b/bitmind/generation/util/prompt.py
@@ -1,22 +1,26 @@
def get_tokenizer_with_min_len(model):
"""
Returns the tokenizer with the smallest maximum token length.
-
+
Args:
model: Single pipeline or dict of pipeline stages.
-
+
Returns:
tuple: (tokenizer, max_token_length)
"""
# Get the model to check for tokenizers
- pipeline = model['stage1'] if isinstance(model, dict) else model
-
+ pipeline = model["stage1"] if isinstance(model, dict) else model
+
# If model has two tokenizers, return the one with smaller max length
- if hasattr(pipeline, 'tokenizer_2'):
+ if hasattr(pipeline, "tokenizer_2"):
len_1 = pipeline.tokenizer.model_max_length
len_2 = pipeline.tokenizer_2.model_max_length
- return (pipeline.tokenizer_2, len_2) if len_2 < len_1 else (pipeline.tokenizer, len_1)
-
+ return (
+ (pipeline.tokenizer_2, len_2)
+ if len_2 < len_1
+ else (pipeline.tokenizer, len_1)
+ )
+
return pipeline.tokenizer, pipeline.tokenizer.model_max_length
@@ -31,12 +35,13 @@ def truncate_prompt_if_too_long(prompt: str, model):
str: The original prompt if within the token limit; otherwise, a truncated version of the prompt.
"""
tokenizer, max_token_len = get_tokenizer_with_min_len(model)
- tokens = tokenizer(prompt, verbose=False) # Suppress token max exceeded warnings
- if len(tokens['input_ids']) < max_token_len:
+ tokens = tokenizer(prompt, verbose=False) # Suppress token max exceeded warnings
+ if len(tokens["input_ids"]) < max_token_len:
return prompt
# Truncate tokens if they exceed the maximum token length, decode the tokens back to a string
- truncated_prompt = tokenizer.decode(token_ids=tokens['input_ids'][:max_token_len-1],
- skip_special_tokens=True)
+ truncated_prompt = tokenizer.decode(
+ token_ids=tokens["input_ids"][: max_token_len - 1], skip_special_tokens=True
+ )
tokens = tokenizer(truncated_prompt)
- return truncated_prompt
\ No newline at end of file
+ return truncated_prompt
diff --git a/bitmind/metagraph.py b/bitmind/metagraph.py
new file mode 100644
index 00000000..5fcde545
--- /dev/null
+++ b/bitmind/metagraph.py
@@ -0,0 +1,108 @@
+import time
+import asyncio
+from typing import Callable, List, Tuple
+import numpy as np
+import bittensor as bt
+from bittensor.utils.weight_utils import process_weights_for_netuid
+
+from bitmind.utils import fail_with_none
+
+import threading
+
+
+def get_miner_uids(
+ metagraph: "bt.metagraph", self_uid: int, vpermit_tao_limit: int
+) -> List[int]:
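+    """Return UIDs of miners with serving axons, excluding this neuron's own UID and
+    any validator-permitted UIDs whose stake exceeds `vpermit_tao_limit`."""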
+ available_uids = []
+ for uid in range(int(metagraph.n.item())):
+ if uid == self_uid:
+ continue
+
+ # Filter non serving axons.
+ if not metagraph.axons[uid].is_serving:
+ continue
+ # Filter validator permit > 1024 stake.
+        if metagraph.validator_permit[uid] and metagraph.S[uid] > vpermit_tao_limit:
+            continue
+        available_uids.append(uid)
+ return available_uids
+
+
+def create_set_weights(version: int, netuid: int):
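+    """Return a `set_weights` closure bound to the given weights version and netuid."""
+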
+ @fail_with_none("Failed setting weights")
+ def set_weights(
+ wallet: "bt.wallet",
+ metagraph: "bt.metagraph",
+ subtensor: "bt.subtensor",
+ weights: Tuple[List[int], List[float]],
+ ):
+ uids, raw_weights = weights
+ if not len(uids):
+ bt.logging.info("No UIDS to score")
+ return
+
+ # Set the weights on chain via our subtensor connection.
+ (
+ processed_weight_uids,
+ processed_weights,
+ ) = process_weights_for_netuid(
+ uids=np.asarray(uids),
+ weights=np.asarray(raw_weights),
+ netuid=netuid,
+ subtensor=subtensor,
+ metagraph=metagraph,
+ )
+
+ bt.logging.info("Setting Weights: " + str(processed_weights))
+ bt.logging.info("Weight Uids: " + str(processed_weight_uids))
+ for _ in range(3):
+ result, message = subtensor.set_weights(
+ wallet=wallet,
+ netuid=netuid,
+ uids=processed_weight_uids, # type: ignore
+ weights=processed_weights,
+ wait_for_finalization=False,
+ wait_for_inclusion=False,
+ version_key=version,
+ max_retries=1,
+ )
+ if result is True:
+ bt.logging.success("set_weights on chain successfully!")
+ break
+ else:
+ bt.logging.error(f"set_weights failed {message}")
+ time.sleep(15)
+
+ return set_weights
+
+
+def create_subscription_handler(substrate, callback: Callable):
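+    """Wrap an async `callback` so it can be passed to a substrate block-header
+    subscription; the callback is invoked with the new block number on each
+    update after the initial header."""
+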
+ def inner(obj, update_nr, _):
+ substrate.get_block(block_number=obj["header"]["number"])
+
+ if update_nr >= 1:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ return loop.run_until_complete(callback(obj["header"]["number"]))
+
+ return inner
+
+
+def start_subscription(substrate, callback: Callable):
+ return substrate.subscribe_block_headers(
+ create_subscription_handler(substrate, callback)
+ )
+
+
+def run_block_callback_thread(substrate, callback: Callable):
+ try:
+ subscription_thread = threading.Thread(
+ target=start_subscription, args=[substrate, callback], daemon=True
+ )
+ subscription_thread.start()
+ bt.logging.info("Block subscription started in background thread.")
+ return subscription_thread
+ except Exception as e:
+ bt.logging.error(f"faaailuuure {callback} - {e}")
diff --git a/bitmind/protocol.py b/bitmind/protocol.py
deleted file mode 100644
index fc05d43b..00000000
--- a/bitmind/protocol.py
+++ /dev/null
@@ -1,229 +0,0 @@
-# The MIT License (MIT)
-# Copyright © 2023 Yuma Rao
-# developer: dubm
-# Copyright © 2023 Bitmind
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
-
-# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-
-from typing import List, Union
-from pydantic import BaseModel, Field
-from torchvision import transforms
-from io import BytesIO
-from PIL import Image
-import bittensor as bt
-import numpy as np
-import base64
-import pydantic
-import torch
-import zlib
-
-from bitmind.validator.config import TARGET_IMAGE_SIZE
-from bitmind.utils.image_transforms import get_base_transforms
-
-base_transforms = get_base_transforms(TARGET_IMAGE_SIZE)
-
-
-# ---- miner ----
-# Example usage:
-# def miner_forward( synapse: ImageSynapse ) -> ImageSynapse:
-# ...
-# synapse.predictions = deepfake_detection_model_outputs
-# return synapse
-# axon = bt.axon().attach( miner_forward ).serve(netuid=...).start()
-
-# ---- validator ---
-# Example usage:
-# dendrite = bt.dendrite()
-# b64_images = [b64_img_1, ..., b64_img_n]
-# predictions = dendrite.query( ImageSynapse( images = b64_images ) )
-# assert len(predictions) == len(b64_images)
-
-
-def prepare_synapse(input_data, modality):
- if isinstance(input_data, torch.Tensor):
- input_data = transforms.ToPILImage()(input_data.cpu().detach())
- if isinstance(input_data, list) and isinstance(input_data[0], torch.Tensor):
- for i, img in enumerate(input_data):
- input_data[i] = transforms.ToPILImage()(img.cpu().detach())
-
- if modality == "image":
- return prepare_image_synapse(input_data)
- elif modality == "video":
- return prepare_video_synapse(input_data)
- else:
- raise NotImplementedError(f"Unsupported modality: {modality}")
-
-
-def prepare_image_synapse(image: Image):
- """
- Prepares an image for use with ImageSynapse object.
-
- Args:
- image (Image): The input image to be prepared.
-
- Returns:
- ImageSynapse: An instance of ImageSynapse containing the encoded image and a default prediction value.
- """
- image_bytes = BytesIO()
- image.save(image_bytes, format="JPEG")
- b64_encoded_image = base64.b64encode(image_bytes.getvalue())
- return ImageSynapse(image=b64_encoded_image)
-
-
-def prepare_video_synapse(frames: List[Image.Image]):
- """
- Prepares video frames for use with VideoSynapse object.
-
- Args:
- frames (List[Image.Image]): The list of video frames to be prepared.
-
- Returns:
- VideoSynapse: An instance of VideoSynapse containing the encoded frames and a default prediction value.
- """
- frame_bytes = []
- for frame in frames:
- buffer = BytesIO()
- frame.save(buffer, format="JPEG")
- frame_bytes.append(buffer.getvalue())
-
- combined_bytes = b"".join(frame_bytes)
- compressed_data = zlib.compress(combined_bytes)
- encoded_data = base64.b85encode(compressed_data).decode("utf-8")
- return VideoSynapse(video=encoded_data)
-
-
-class MediaSynapse(bt.Synapse):
-
- testnet_label: int = -1 # for miners to monitor their performance on testnet
-
- prediction: Union[float, List[float]] = pydantic.Field(
- title="Prediction",
- description="Probability vector for [real, synthetic, semi-synthetic] classes.",
- default=[-1.0, -1.0, -1.0],
- frozen=False,
- )
-
- def deserialize(self) -> np.ndarray:
- """
- Deserialize the output. Backwards compatible with binary float outputs.
-
- Returns:
- - float: The deserialized miner prediction probabilities
- """
- p = self.prediction
- if isinstance(p, float):
- if p == -1:
- return np.array([-1.0, -1.0, -1.0])
- else:
- return np.array([1 - p, p, 0.0])
- elif isinstance(p, list):
- if len(p) == 2:
- p += [0.0] # assume 2-dim responses are [real, fake]
- return np.array(p)
- else:
- raise ValueError(f"Unsupported prediction type: {type(p)}")
-
-
-class ImageSynapse(MediaSynapse):
- """
- This protocol helps in handling image/prediction request and response communication between
- the miner and the validator.
-
- Attributes:
- - image: a bas64 encoded images
- - prediction: a float indicating the probabilty that the image is AI generated/modified.
- >.5 is considered generated/modified, <= 0.5 is considered real.
- """
-
- image: str = pydantic.Field(
- title="Image", description="A base64 encoded image", default="", frozen=False
- )
-
-
-class VideoSynapse(MediaSynapse):
- """
- Naive initial VideoSynapse (Epistula version coming soon I promise)
- """
-
- # Required request input, filled by sending dendrite caller.
- video: str = pydantic.Field(
- title="Video",
- description="A wildly inefficient means of sending video data",
- default="",
- frozen=False,
- )
-
-
-def decode_video_synapse(synapse: VideoSynapse) -> List[torch.Tensor]:
- """
- V1 of a function for decoding a VideoSynapse object back into a list of torch tensors.
-
- Args:
- synapse: VideoSynapse object containing the encoded video data
-
- Returns:
- List of torch tensors, each representing a frame from the video
- """
- compressed_data = base64.b85decode(synapse.video.encode("utf-8"))
- combined_bytes = zlib.decompress(compressed_data)
-
- # Split the combined bytes into individual JPEG files
- # Look for JPEG markers: FF D8 (start) and FF D9 (end)
- frames = []
- current_pos = 0
- data_length = len(combined_bytes)
-
- while current_pos < data_length:
- # Find start of JPEG (FF D8)
- while current_pos < data_length - 1:
- if (
- combined_bytes[current_pos] == 0xFF
- and combined_bytes[current_pos + 1] == 0xD8
- ):
- break
- current_pos += 1
-
- if current_pos >= data_length - 1:
- break
-
- start_pos = current_pos
-
- # Find end of JPEG (FF D9)
- while current_pos < data_length - 1:
- if (
- combined_bytes[current_pos] == 0xFF
- and combined_bytes[current_pos + 1] == 0xD9
- ):
- current_pos += 2
- break
- current_pos += 1
-
- if current_pos > start_pos:
- # Extract the JPEG data
- jpeg_data = combined_bytes[start_pos:current_pos]
- try:
- img = Image.open(BytesIO(jpeg_data))
- frames.append(img)
- except Exception as e:
- print(f"Error processing frame: {e}")
- continue
-
- bt.logging.info("transforming video inputs")
- frames = base_transforms(frames)
-
- frames = torch.stack(frames, dim=0)
- frames = frames.unsqueeze(0)
- print(f"decoded video into tensor with shape {frames.shape}")
- return frames
diff --git a/bitmind/scoring/__init__.py b/bitmind/scoring/__init__.py
new file mode 100644
index 00000000..ac1a0dd3
--- /dev/null
+++ b/bitmind/scoring/__init__.py
@@ -0,0 +1 @@
+from .eval_engine import EvalEngine
diff --git a/bitmind/scoring/eval_engine.py b/bitmind/scoring/eval_engine.py
new file mode 100644
index 00000000..3004bb39
--- /dev/null
+++ b/bitmind/scoring/eval_engine.py
@@ -0,0 +1,349 @@
+from typing import List, Dict, Tuple, Any, Optional
+import bittensor as bt
+import numpy as np
+import json
+from sklearn.metrics import matthews_corrcoef
+import os
+
+from bitmind.types import Modality
+from bitmind.scoring.miner_history import MinerHistory
+
+
+class EvalEngine:
+ """
+ A class to track rewards and compute weights for miners based on their
+ prediction performance.
+ """
+
+ def __init__(
+ self,
+ metagraph: bt.metagraph,
+ config: bt.config,
+ ):
+ assert config.neuron.full_path
+ assert (
+ abs(config.scoring.image_weight + config.scoring.video_weight - 1.0) < 1e-6
+ ), "Modality weights must sum to 1.0"
+ assert (
+ abs(config.scoring.binary_weight + config.scoring.multiclass_weight - 1.0)
+ < 1e-6
+ ), "Binary/Multiclass weights must sum to 1.0"
+
+ self.metagraph = metagraph
+ self.config = config
+ self.scores = np.zeros(self.metagraph.n, dtype=np.float32)
+ self.tracker = MinerHistory()
+ self.miner_metrics = {}
+ self.load_state(save_dir=self.config.neuron.full_path)
+
+ def get_weights(self):
+ """Returns an L1 normalized vector of scores (rewards EMA)."""
+
+ if np.isnan(self.scores).any():
+ bt.logging.warning(
+ f"Scores contain NaN values. This may be due to a lack of responses from miners, or a bug in your reward functions."
+ )
+
+ norm = np.linalg.norm(self.scores, ord=1, axis=0, keepdims=True)
+ if np.any(norm == 0) or np.isnan(norm).any():
+ norm = np.ones_like(norm) # Avoid division by zero or NaN
+
+ return self.scores / norm
+
+ def score_challenge(
+ self,
+ challenge_results: dict,
+ label: int,
+ challenge_modality: Modality,
+    ) -> Dict[int, float]:
+ """Update miner prediction history, compute instantaneous rewards, update score EMA"""
+
+ predictions = [np.array(r["prediction"]) for r in challenge_results]
+ hotkeys = [r["hotkey"] for r in challenge_results]
+ uids = [r["uid"] for r in challenge_results]
+
+ rewards = self._get_rewards_for_challenge(
+ label, predictions, uids, hotkeys, challenge_modality
+ )
+ self._update_scores(rewards)
+ return rewards
+
+ def _update_scores(self, rewards: dict):
+ """Performs exponential moving average on the scores based on the rewards received from the miners."""
+
+ uids = list(rewards.keys())
+ rewards = np.array([rewards[uid] for uid in uids])
+ bt.logging.trace(f"updating scores {uids} : {rewards}")
+
+ if np.isnan(rewards).any():
+ bt.logging.warning(f"NaN values detected in rewards: {rewards}")
+ rewards = np.nan_to_num(rewards, nan=0)
+
+ rewards = np.asarray(rewards)
+ if isinstance(uids, np.ndarray):
+ uids_array = uids.copy()
+ else:
+ uids_array = np.array(uids)
+
+ if rewards.size == 0 or uids_array.size == 0:
+ bt.logging.warning(
+ "Either rewards or uids_array is empty. No updates will be performed."
+ )
+ return
+
+ if rewards.size != uids_array.size:
+ raise ValueError(
+ f"Shape mismatch: rewards array of shape {rewards.shape} "
+ f"cannot be broadcast to uids array of shape {uids_array.shape}"
+ )
+
+ # Compute forward pass rewards, assumes uids are mutually exclusive.
+ # shape: [ metagraph.n ]
+ self.maybe_extend_scores(np.max(uids) + 1)
+ scattered_rewards: np.ndarray = np.full_like(self.scores, 0.5)
+ vali_uids = [
+ uid
+ for uid in range(len(scattered_rewards))
+ if self.metagraph.validator_permit[uid]
+ and self.metagraph.S[uid] > self.config.vpermit_tao_limit
+ ]
+ no_response_uids = [
+ uid
+ for uid in range(len(scattered_rewards))
+ if all(
+ [
+ count == 0
+ for modality, count in self.tracker.get_prediction_count(
+ uid
+ ).items()
+ ]
+ )
+ ]
+ scattered_rewards[vali_uids] = 0.0
+ scattered_rewards[no_response_uids] = 0.0
+ scattered_rewards[uids_array] = rewards
+ bt.logging.debug(f"Scattered rewards: {rewards}")
+
+ # Update scores with rewards produced by this step.
+ # shape: [ metagraph.n ]
+ alpha: float = self.config.scoring.moving_average_alpha
+ self.scores: np.ndarray = alpha * scattered_rewards + (1 - alpha) * self.scores
+ bt.logging.debug(f"Updated moving avg scores: {self.scores}")
+
+ def _get_rewards_for_challenge(
+ self,
+ label: int,
+ responses: List[np.ndarray],
+ uids: List[int],
+        hotkeys: List[str],
+ challenge_modality: Modality,
+    ) -> Dict[int, float]:
+ """
+ Calculate rewards for miner responses based on performance metrics.
+
+ Args:
+ label: The true label (0 for real, 1 for synthetic, 2 for semi-synthetic)
+ responses: List of probability vectors from miners, each shape (3,)
+ uids: List of miner UIDs
+            hotkeys: List of miner hotkeys
+ challenge_modality: Type of challenge (Modality.VIDEO or Modality.IMAGE)
+
+ Returns:
+            Dict mapping each miner UID to its combined (image/video weighted) reward.
+            Per-miner metrics are stored on `self.miner_metrics`.
+ """
+ miner_rewards = {}
+ for hotkey, uid, pred_probs in zip(hotkeys, uids, responses):
+ miner_modality_rewards = {}
+ miner_modality_metrics = {}
+
+ self.tracker.update(uid, pred_probs, label, challenge_modality, hotkey)
+
+ for modality in Modality:
+ try:
+ modality = modality.value
+ pred_count = self.tracker.get_prediction_count(uid).get(modality, 0)
+ if pred_count < 5:
+ miner_modality_rewards[modality] = 0.0
+ miner_modality_metrics[modality] = self._empty_metrics()
+ continue
+
+ metrics = self._get_metrics(uid, modality, window=100)
+
+ binary_weight = self.config.scoring.binary_weight
+ multiclass_weight = self.config.scoring.multiclass_weight
+ reward = (
+ binary_weight * metrics["binary_mcc"]
+ + multiclass_weight * metrics["multi_class_mcc"]
+ )
+
+ if modality == challenge_modality:
+ reward *= self.compute_penalty_multiplier(pred_probs)
+
+ miner_modality_rewards[modality] = reward
+ miner_modality_metrics[modality] = metrics
+
+ except Exception as e:
+ bt.logging.error(
+ f"Couldn't calculate reward for miner {uid}, "
+ f"prediction: {pred_probs}, label: {label}, modality: {modality}"
+ )
+ bt.logging.exception(e)
+ miner_modality_rewards[modality] = 0.0
+ miner_modality_metrics[modality] = self._empty_metrics()
+
+ image_weight = self.config.scoring.image_weight
+ video_weight = self.config.scoring.video_weight
+ image_rewards = miner_modality_rewards.get(Modality.IMAGE, 0.0)
+ video_rewards = miner_modality_rewards.get(Modality.VIDEO, 0.0)
+ total_reward = (image_weight * image_rewards) + (
+ video_weight * video_rewards
+ )
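+            # Illustrative example (weights come from config.scoring): with
+            # image_weight=0.6 and video_weight=0.4, an image reward of 0.8 and a
+            # video reward of 0.5 combine to 0.6*0.8 + 0.4*0.5 = 0.68.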
+
+ miner_rewards[uid] = total_reward
+ self.miner_metrics[uid] = miner_modality_metrics
+
+ return miner_rewards
+
+ def _get_metrics(
+ self, uid: int, modality: Modality, window: Optional[int] = None
+ ) -> Dict[str, float]:
+ """
+ Calculate performance metrics for a specific miner and modality.
+
+ Args:
+ uid: The miner's UID
+ modality: The modality to calculate metrics for
+ window: Number of recent predictions to consider (default: None = all)
+
+ Returns:
+ Dict containing performance metrics
+ """
+ if (
+ uid not in self.tracker.predictions
+ or modality not in self.tracker.predictions[uid]
+ ):
+ return self._empty_metrics()
+
+ recent_preds = list(self.tracker.predictions[uid][modality])
+ recent_labels = list(self.tracker.labels[uid][modality])
+
+ if len(recent_labels) != len(recent_preds):
+ bt.logging.critical(
+ f"Prediction and label array size mismatch ({len(recent_preds)} and {len(recent_labels)})"
+ )
+ bt.logging.critical(
+ f"Clearing miner history for {uid} to allow scoring to resume"
+ )
+ self.tracker.reset_miner_history(uid)
+ return self._empty_metrics()
+
+ if window is not None:
+ window = min(window, len(recent_preds))
+ recent_preds = recent_preds[-window:]
+ recent_labels = recent_labels[-window:]
+
+ pred_probs = np.array([p for p in recent_preds if not np.array_equal(p, -1)])
+ labels = np.array(
+ [
+ l
+ for i, l in enumerate(recent_labels)
+ if not np.array_equal(recent_preds[i], -1)
+ ]
+ )
+
+ if len(labels) == 0 or len(pred_probs) == 0:
+ return self._empty_metrics()
+
+ try:
+ predictions = np.argmax(pred_probs, axis=1)
+
+ # Multi-class MCC (real vs synthetic vs semi-synthetic)
+ multi_class_mcc = matthews_corrcoef(labels, predictions)
+
+ # Binary MCC (real vs any synthetic)
+ binary_labels = (labels > 0).astype(int)
+ binary_preds = (predictions > 0).astype(int)
+ binary_mcc = matthews_corrcoef(binary_labels, binary_preds)
+
+ return {"multi_class_mcc": multi_class_mcc, "binary_mcc": binary_mcc}
+ except Exception as e:
+ bt.logging.warning(f"Error in reward computation: {e}")
+ return self._empty_metrics()
+
+ def get_miner_metrics(self, uid):
+ return self.miner_metrics.get(uid, self._empty_metrics())
+
+ def _empty_metrics(self):
+ """Return a dictionary of empty metrics."""
+ return {"multi_class_mcc": 0, "binary_mcc": 0}
+
+ def sync_to_metagraph(self):
+ """Just zeros out scores for dereg'd miners. MinerHistory class
+ handles clearing predictio history in `update` when a new hotkey
+ is detected"""
+ hotkeys = self.tracker.miner_hotkeys
+ for uid, hotkey in enumerate(hotkeys):
+ if hotkey != self.metagraph.hotkeys[uid]:
+ self.scores[uid] = 0 # hotkey has been replaced
+ self.maybe_extend_scores(self.metagraph.n)
+
+ def maybe_extend_scores(self, n):
+ """Only for the case where metagraph.n is still growing"""
+ if n > len(self.scores):
+ n_before = len(self.scores)
+ new_moving_average = np.zeros((n))
+ new_moving_average[:n_before] = self.scores[:n_before]
+ self.scores = new_moving_average
+
+ def save_state(self, save_dir):
+ os.makedirs(save_dir, exist_ok=True)
+ scores_path = os.path.join(save_dir, "scores.npy")
+ np.save(scores_path, self.scores)
+ self.tracker.save_state(save_dir)
+ bt.logging.trace(f"Saved state to {save_dir}")
+
+ def load_state(self, save_dir):
+ self.tracker.load_state(save_dir)
+ scores_path = os.path.join(save_dir, "scores.npy")
+ if not os.path.isfile(scores_path):
+ bt.logging.info(f"No saved scores found at {scores_path}")
+ return False
+ try:
+ self.scores = np.load(scores_path)
+ return True
+ except Exception as e:
+ bt.logging.error(f"Error deserializing scores: {str(e)}")
+ return False
+
+ @staticmethod
+ def compute_penalty_multiplier(y_pred: np.ndarray) -> float:
+ """
+ Compute penalty for predictions outside valid range.
+
+ Args:
+ y_pred: Predicted probabilities for each class, shape (3,)
+
+ Returns:
+ float: 0.0 if prediction is invalid, 1.0 if valid
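+
+        Example (illustrative values):
+            compute_penalty_multiplier(np.array([0.7, 0.2, 0.1]))  # -> 1.0 (valid)
+            compute_penalty_multiplier(np.array([0.7, 0.7, 0.1]))  # -> 0.0 (sums to 1.5)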
+ """
+ sum_check = np.abs(np.sum(y_pred) - 1.0) < 1e-6
+ range_check = np.all((y_pred >= 0.0) & (y_pred <= 1.0))
+ return 1.0 if (sum_check and range_check) else 0.0
+
+ @staticmethod
+ def transform_reward(reward: float, pole: float = 1.01) -> float:
+ """
+ Transform reward using an inverse function.
+
+ Args:
+ reward: Raw reward value
+ pole: Pole parameter for transformation
+
+ Returns:
+ float: Transformed reward
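+
+        Example (illustrative): with the default pole of 1.01, a raw reward of 0.5
+        maps to 1 / (1.01 - 0.5) ≈ 1.96, and a reward of 1.0 maps to ≈ 100.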
+ """
+ if reward == 0:
+ return 0
+ return 1 / (pole - np.array(reward))
diff --git a/bitmind/scoring/miner_history.py b/bitmind/scoring/miner_history.py
new file mode 100644
index 00000000..4382c178
--- /dev/null
+++ b/bitmind/scoring/miner_history.py
@@ -0,0 +1,111 @@
+from typing import Dict
+from collections import deque
+import bittensor as bt
+import numpy as np
+import joblib
+import traceback
+import os
+
+from bitmind.types import Modality
+
+
+class MinerHistory:
+ """Tracks all recent miner performance to facilitate reward computation."""
+
+ VERSION = 2
+
+ def __init__(self, store_last_n_predictions: int = 100):
+ self.predictions: Dict[int, Dict[Modality, deque]] = {}
+ self.labels: Dict[int, Dict[Modality, deque]] = {}
+ self.miner_hotkeys: Dict[int, str] = {}
+ self.store_last_n_predictions = store_last_n_predictions
+ self.version = self.VERSION
+
+ def update(
+ self,
+ uid: int,
+ prediction: np.ndarray,
+ label: int,
+ modality: Modality,
+ miner_hotkey: str,
+ ):
+ """Update the miner prediction history.
+
+ Args:
+ prediction: numpy array of shape (3,) containing probabilities for
+ [real, synthetic, semi-synthetic]
+ label: integer label (0 for real, 1 for synthetic, 2 for semi-synthetic)
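+            modality: Modality.IMAGE or Modality.VIDEO
+            miner_hotkey: the miner's current hotkey; a changed hotkey resets the
+                stored history for that UID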
+ """
+ if uid not in self.miner_hotkeys or self.miner_hotkeys[uid] != miner_hotkey:
+ self.reset_miner_history(uid, miner_hotkey)
+ bt.logging.info(f"Reset history for {uid} {miner_hotkey}")
+
+ self.predictions[uid][modality].append(np.array(prediction))
+ self.labels[uid][modality].append(label)
+
+ def _reset_predictions(self, uid: int):
+ self.predictions[uid] = {
+ Modality.IMAGE: deque(maxlen=self.store_last_n_predictions),
+ Modality.VIDEO: deque(maxlen=self.store_last_n_predictions),
+ }
+
+ def _reset_labels(self, uid: int):
+ self.labels[uid] = {
+ Modality.IMAGE: deque(maxlen=self.store_last_n_predictions),
+ Modality.VIDEO: deque(maxlen=self.store_last_n_predictions),
+ }
+
+ def reset_miner_history(self, uid: int, miner_hotkey: str):
+ """Reset the history for a miner."""
+ self._reset_predictions(uid)
+ self._reset_labels(uid)
+ self.miner_hotkeys[uid] = miner_hotkey
+
+    def get_prediction_count(self, uid: int) -> Dict[Modality, int]:
+        """Get the number of stored predictions per modality for a specific miner."""
+ counts = {}
+ for modality in [Modality.IMAGE, Modality.VIDEO]:
+ if uid not in self.predictions or modality not in self.predictions[uid]:
+ counts[modality] = 0
+ else:
+ counts[modality] = len(self.predictions[uid][modality])
+ return counts
+
+ def save_state(self, save_dir):
+ path = os.path.join(save_dir, "history.pkl")
+ state = {
+ "version": self.version,
+ "store_last_n_predictions": self.store_last_n_predictions,
+ "miner_hotkeys": self.miner_hotkeys,
+ "predictions": self.predictions,
+ "labels": self.labels,
+ }
+ joblib.dump(state, path)
+
+ def load_state(self, save_dir):
+ path = os.path.join(save_dir, "history.pkl")
+ if not os.path.isfile(path):
+ bt.logging.warning(f"No saved state found at {path}")
+ return False
+
+ try:
+ state = joblib.load(path)
+ if state["version"] != self.VERSION:
+ bt.logging.warning(
+ f"Loading state from different version: {state['version']} != {self.VERSION}"
+ )
+
+ self.version = state["version"]
+ self.store_last_n_predictions = state["store_last_n_predictions"]
+ self.miner_hotkeys = state["miner_hotkeys"]
+ self.predictions = state["predictions"]
+ self.labels = state["labels"]
+ bt.logging.debug(
+ f"Successfully loaded history for {len(self.miner_hotkeys)} miners"
+ )
+ return True
+
+ except Exception as e:
+ bt.logging.error(f"Error deserializing MinerHistory state: {str(e)}")
+ bt.logging.error(traceback.format_exc())
+ return False
diff --git a/bitmind/synthetic_data_generation/README.md b/bitmind/synthetic_data_generation/README.md
deleted file mode 100644
index 80b2e51c..00000000
--- a/bitmind/synthetic_data_generation/README.md
+++ /dev/null
@@ -1,4 +0,0 @@
-
-# Synthetic Image Generation
-
-This folder contains files for the implementation of a joint vision-to-language and text-to-image model system that generates highly diverse and realistic images for deepfake detector training and Subnet 34 validating.
\ No newline at end of file
diff --git a/bitmind/synthetic_data_generation/__init__.py b/bitmind/synthetic_data_generation/__init__.py
deleted file mode 100644
index 5a2682ac..00000000
--- a/bitmind/synthetic_data_generation/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .synthetic_data_generator import SyntheticDataGenerator
-from .prompt_generator import PromptGenerator
diff --git a/bitmind/synthetic_data_generation/synthetic_data_generator.py b/bitmind/synthetic_data_generation/synthetic_data_generator.py
deleted file mode 100644
index 36257ad7..00000000
--- a/bitmind/synthetic_data_generation/synthetic_data_generator.py
+++ /dev/null
@@ -1,586 +0,0 @@
-import gc
-import json
-import os
-import random
-import time
-import warnings
-from pathlib import Path
-from typing import Dict, Optional, Any, Union
-from itertools import zip_longest
-
-import bittensor as bt
-import numpy as np
-import torch
-from diffusers.utils import export_to_video
-from PIL import Image
-
-from bitmind.validator.config import (
- HUGGINGFACE_CACHE_DIR,
- TEXT_MODERATION_MODEL,
- IMAGE_ANNOTATION_MODEL,
- MODELS,
- MODEL_NAMES,
- T2V_MODEL_NAMES,
- T2I_MODEL_NAMES,
- I2I_MODEL_NAMES,
- I2V_MODEL_NAMES,
- TARGET_IMAGE_SIZE,
- select_random_model,
- get_task,
- get_modality,
- get_output_media_type,
- MediaType,
- Modality
-)
-from bitmind.synthetic_data_generation.image_utils import create_random_mask, is_black_image
-from bitmind.synthetic_data_generation.prompt_utils import truncate_prompt_if_too_long
-from bitmind.synthetic_data_generation.prompt_generator import PromptGenerator
-from bitmind.validator.cache import ImageCache
-from bitmind.validator.model_utils import (
- load_hunyuanvideo_transformer,
- load_annimatediff_motion_adapter,
- JanusWrapper,
- create_pipeline_generator,
- enable_model_optimizations
-)
-
-
-future_warning_modules_to_ignore = [
- 'diffusers',
- 'transformers.tokenization_utils_base'
-]
-
-for module in future_warning_modules_to_ignore:
- warnings.filterwarnings("ignore", category=FutureWarning, module=module)
-
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-torch.set_float32_matmul_precision('high')
-
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
-os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
-
-
-class SyntheticDataGenerator:
- """
- A class for generating synthetic images and videos based on text prompts.
-
- This class supports different prompt generation strategies and can utilize
- various text-to-video (t2v) and text-to-image (t2i) models.
-
- Attributes:
- use_random_model: Whether to randomly select a t2v or t2i for each
- generation task.
- prompt_type: The type of prompt generation strategy ('random', 'annotation').
- prompt_generator_name: Name of the prompt generation model.
- model_name: Name of the t2v, t2i, or i2i model.
- prompt_generator: The vlm/llm pipeline for generating input prompts for t2i/t2v models
- output_dir: Directory to write generated data.
- """
-
- def __init__(
- self,
- model_name: Optional[str] = None,
- use_random_model: bool = True,
- prompt_type: str = 'annotation',
- output_dir: Optional[Union[str, Path]] = None,
- image_cache: Optional[ImageCache] = None,
- device: str = 'cuda'
- ) -> None:
- """
- Initialize the SyntheticDataGenerator.
-
- Args:
- model_name: Name of the generative image/video model
- use_random_model: Whether to randomly select models for generation.
- prompt_type: The type of prompt generation strategy.
- output_dir: Directory to write generated data.
- device: Device identifier.
- image_cache: Optional image cache instance.
-
- Raises:
- ValueError: If an invalid model name is provided.
- NotImplementedError: If an unsupported prompt type is specified.
- """
- if not use_random_model and model_name not in MODEL_NAMES:
- raise ValueError(
- f"Invalid model name '{model_name}'. "
- f"Options are {MODEL_NAMES}"
- )
-
- self.use_random_model = use_random_model
- self.model_name = model_name
- self.model = None
- self.device = device
-
- if self.use_random_model and model_name is not None:
- bt.logging.warning(
- "model_name will be ignored (use_random_model=True)"
- )
- self.model_name = None
-
- self.prompt_type = prompt_type
- self.image_cache = image_cache
- if self.prompt_type == 'annotation' and self.image_cache is None:
- raise ValueError(f"image_cache cannot be None if prompt_type == 'annotation'")
-
- self.prompt_generator = PromptGenerator(
- vlm_name=IMAGE_ANNOTATION_MODEL,
- llm_name=TEXT_MODERATION_MODEL
- )
-
- self.output_dir = Path(output_dir) if output_dir else None
- if self.output_dir:
- (self.output_dir / Modality.IMAGE / MediaType.SYNTHETIC).mkdir(parents=True, exist_ok=True)
- (self.output_dir / Modality.IMAGE / MediaType.SEMISYNTHETIC).mkdir(parents=True, exist_ok=True)
- (self.output_dir / Modality.VIDEO / MediaType.SYNTHETIC).mkdir(parents=True, exist_ok=True)
- (self.output_dir / Modality.VIDEO / MediaType.SEMISYNTHETIC).mkdir(parents=True, exist_ok=True)
-
- def batch_generate(self, batch_size: int = 5) -> None:
- """
- Asynchronously generate synthetic data in batches.
-
- This method handles the complete generation pipeline:
- 1. Samples source images for captioning
- 2. Determines which models to use (random or specific)
- 3. Groups models by type (video vs image)
- 4. Generates prompts by task type
- 5. Performs generation for each model using appropriate prompts
- 6. Saves outputs with metadata
-
- Args:
- batch_size: Number of prompts/generations to create per model
- """
- # Step 1: Sample source images to reuse across all generations
- images = []
- bt.logging.info(f"Starting batch generation with size {batch_size}")
- for i in range(batch_size):
- image_sample = self.image_cache.sample()
- images.append(image_sample['image'])
- bt.logging.info(f"Sampled image {i+1}/{batch_size} for captioning: {image_sample['path']}")
-
- # Step 2: Determine which models to use
- # Either use a single specified model or create a balanced random selection
- if not self.use_random_model and self.model_name:
- model_names = [self.model_name]
- else:
- # Shuffle each model type separately to maintain diversity
- # Then interleave them to distribute different types evenly
- i2i_model_names = random.sample(I2I_MODEL_NAMES, len(I2I_MODEL_NAMES))
- t2i_model_names = random.sample(T2I_MODEL_NAMES, len(T2I_MODEL_NAMES))
- t2v_model_names = random.sample(T2V_MODEL_NAMES, len(T2V_MODEL_NAMES))
- i2v_model_names = random.sample(I2V_MODEL_NAMES, len(I2V_MODEL_NAMES))
-
- # Interleave all model types to ensure variety in generation order
- model_names = [
- m for quad in zip_longest(t2v_model_names, t2i_model_names,
- i2i_model_names, i2v_model_names)
- for m in quad if m is not None
- ]
-
- bt.logging.info(f"Using {'random models' if self.use_random_model else f'specific model: {self.model_name}'}")
-
- # Step 3: Group models by type to minimize prompt generator reloading
- # Video models (t2v, i2v) need enhanced prompts with motion
- # Image models (t2i, i2i) use standard prompts
- video_models = [] # t2v, i2v models
- image_models = [] # t2i, i2i models
- for model_name in model_names:
- task = get_task(model_name)
- if task in ['t2v', 'i2v']:
- video_models.append((model_name, task))
- else:
- image_models.append((model_name, task))
-
- # Log model distribution for monitoring
- bt.logging.info(f"Model distribution:")
- bt.logging.info(f"- Video models ({len(video_models)}): {[m[0] for m in video_models]}")
- bt.logging.info(f"- Image models ({len(image_models)}): {[m[0] for m in image_models]}")
-
- # Step 4: Process each group (video/image) separately
- # This allows us to generate appropriate prompts once per type
- for is_video_task, model_group in [
- (True, video_models),
- (False, image_models)
- ]:
- if not model_group:
- bt.logging.info(f"Skipping {'video' if is_video_task else 'image'} generation - no models in group")
- continue
-
- # Generate all prompts for this task type at once
- prompts = []
- bt.logging.info(f"Generating {batch_size} {'video' if is_video_task else 'image'} prompts")
-
- try:
- for i in range(batch_size):
- # Use the first model's task type as reference for prompt generation
- reference_task = model_group[0][1] # Each entry is (model_name, task)
- prompts.append(self.generate_prompt(
- image=images[i],
- clear_gpu=False, # Keep prompt generator loaded until all prompts are done
- task=reference_task
- ))
- bt.logging.info(f"Caption {i+1}/{batch_size} generated: {prompts[-1]}")
- finally:
- # Ensure prompt generator is cleared from GPU
- self.prompt_generator.clear_gpu()
- torch.cuda.empty_cache()
- gc.collect()
-
- # Step 5: Generate outputs for each model using the prepared prompts
- for model_name, task in model_group:
- modality = get_modality(model_name)
- media_type = get_output_media_type(model_name)
-
- for i, prompt in enumerate(prompts):
- bt.logging.info(f"Started generation {i+1}/{batch_size} | Model: {model_name} | Prompt: {prompt}")
-
- # Generate the actual output (image/video)
- output = self._run_generation(prompt, task=task, model_name=model_name, image=images[i])
-
- # Step 6: Save the generated output and its metadata
- # Organize outputs by modality/type/model for easy retrieval
- model_output_dir = self.output_dir / modality / media_type / model_name.split('/')[1]
- model_output_dir.mkdir(parents=True, exist_ok=True)
- base_path = model_output_dir / str(output['time'])
-
- bt.logging.info(f'Writing to cache {model_output_dir}')
-
- # Save metadata separately from the actual output
- metadata = {k: v for k, v in output.items() if k != 'gen_output' and 'image' not in k}
- base_path.with_suffix('.json').write_text(json.dumps(metadata))
-
- # Save the output in appropriate format based on modality
- if modality == Modality.IMAGE:
- out_path = base_path.with_suffix('.png')
- output['gen_output'].images[0].save(out_path)
- elif modality == Modality.VIDEO:
- bt.logging.info("Writing to cache")
- out_path = str(base_path.with_suffix('.mp4'))
- export_to_video(output['gen_output'].frames[0], out_path, fps=30)
- bt.logging.info(f"Wrote to {out_path}")
-
- def generate(
- self,
- image: Optional[Image.Image] = None,
- task: Optional[str] = None,
- model_name: Optional[str] = None
- ) -> Dict[str, Any]:
- """
- Generate synthetic data based on input parameters.
-
- Args:
- image: Input image for annotation-based generation.
- modality: Type of media to generate ('image' or 'video').
-
- Returns:
- Dictionary containing generated data information.
-
- Raises:
- ValueError: If real_image is None when using annotation prompt type.
- NotImplementedError: If prompt type is not supported.
- """
- prompt = self.generate_prompt(image, clear_gpu=True, task=task)
- bt.logging.info("Generating synthetic data...")
- gen_data = self._run_generation(prompt, task, model_name, image)
- self.clear_gpu()
- return gen_data
-
- def generate_prompt(
- self,
- image: Optional[Image.Image] = None,
- clear_gpu: bool = True,
- task: Optional[str] = None
- ) -> str:
- """Generate a prompt based on the specified strategy."""
- bt.logging.info("Generating prompt")
- if self.prompt_type == 'annotation':
- if image is None:
- raise ValueError(
- "image can't be None if self.prompt_type is 'annotation'"
- )
- self.prompt_generator.load_models()
- prompt = self.prompt_generator.generate(image, task=task)
- if clear_gpu:
- self.prompt_generator.clear_gpu()
- else:
- raise NotImplementedError(f"Unsupported prompt type: {self.prompt_type}")
- return prompt
-
- def _run_generation(
- self,
- prompt: str,
- task: Optional[str] = None,
- model_name: Optional[str] = None,
- image: Optional[Image.Image] = None,
- generate_at_target_size: bool = False,
- ) -> Dict[str, Any]:
- """
- Generate synthetic data based on a text prompt.
-
- Args:
- prompt: The text prompt used to inspire the generation.
- task: The generation task type ('t2i', 't2v', 'i2i', 'i2v', or None).
- model_name: Optional model name to use for generation.
- image: Optional input image for image-to-image or image-to-video generation.
- generate_at_target_size: If True, generate at TARGET_IMAGE_SIZE dimensions.
-
- Returns:
- Dictionary containing generated data and metadata.
-
- Raises:
- RuntimeError: If generation fails.
- """
- # Clear CUDA cache before loading model
- torch.cuda.empty_cache()
- gc.collect()
-
- self.load_model(model_name)
- model_config = MODELS[self.model_name]
- task = get_task(model_name) if task is None else task
-
- bt.logging.info("Preparing generation arguments")
- gen_args = model_config.get('generate_args', {}).copy()
- mask_center = None
-
- # prep inpainting-specific generation args
- if task == 'i2i':
- # Use larger image size for better inpainting quality
- target_size = (1024, 1024)
- if image.size[0] > target_size[0] or image.size[1] > target_size[1]:
- image = image.resize(target_size, Image.Resampling.LANCZOS)
-
- gen_args['mask_image'], mask_center = create_random_mask(image.size)
- gen_args['image'] = image
- # prep image-to-video generation args
- elif task == 'i2v':
- if image is None:
- raise ValueError("image cannot be None for image-to-video generation")
- # Get target size from gen_args if specified, otherwise use default
- target_size = (
- gen_args.get('height', 768),
- gen_args.get('width', 768)
- )
- if image.size[0] > target_size[0] or image.size[1] > target_size[1]:
- image = image.resize(target_size, Image.Resampling.LANCZOS)
- gen_args['image'] = image
-
- # Prepare generation arguments
- for k, v in gen_args.items():
- if isinstance(v, dict):
- if "min" in v and "max" in v:
- # For i2v, use minimum values to save memory
- if task == 'i2v':
- gen_args[k] = v['min']
- else:
- gen_args[k] = np.random.randint(v['min'], v['max'])
- if "options" in v:
- gen_args[k] = random.choice(v['options'])
- # Ensure num_frames is always an integer
- if k == 'num_frames' and isinstance(v, dict):
- if "min" in v:
- gen_args[k] = v['min']
- elif "max" in v:
- gen_args[k] = v['max']
- else:
- gen_args[k] = 24 # Default value
-
- try:
- if generate_at_target_size:
- gen_args['height'] = TARGET_IMAGE_SIZE[0]
- gen_args['width'] = TARGET_IMAGE_SIZE[1]
- elif 'resolution' in gen_args:
- gen_args['height'] = gen_args['resolution'][0]
- gen_args['width'] = gen_args['resolution'][1]
- del gen_args['resolution']
-
- # Ensure num_frames is an integer before generation
- if 'num_frames' in gen_args:
- gen_args['num_frames'] = int(gen_args['num_frames'])
-
- truncated_prompt = truncate_prompt_if_too_long(prompt, self.model)
- bt.logging.info(f"Generating media from prompt: {truncated_prompt}")
- bt.logging.info(f"Generation args: {gen_args}")
-
- start_time = time.time()
-
- # Create pipeline-specific generator
- generate = create_pipeline_generator(model_config, self.model)
-
- # Handle autocast if needed
- if model_config.get('use_autocast', True):
- pretrained_args = model_config.get('from_pretrained_args', {})
- torch_dtype = pretrained_args.get('torch_dtype', torch.bfloat16)
- with torch.autocast(self.device, torch_dtype, cache_enabled=False):
- # Clear CUDA cache before generation
- torch.cuda.empty_cache()
- gc.collect()
- gen_output = generate(truncated_prompt, **gen_args)
- else:
- # Clear CUDA cache before generation
- torch.cuda.empty_cache()
- gc.collect()
- gen_output = generate(truncated_prompt, **gen_args)
-
- gen_time = time.time() - start_time
-
- except Exception as e:
- if generate_at_target_size:
- bt.logging.error(
- f"Attempt with custom dimensions failed, falling back to "
- f"default dimensions. Error: {e}"
- )
- try:
- # Clear CUDA cache before retry
- torch.cuda.empty_cache()
- gen_output = self.model(prompt=truncated_prompt)
- gen_time = time.time() - start_time
- except Exception as fallback_error:
- bt.logging.error(
- f"Failed to generate image with default dimensions after "
- f"initial failure: {fallback_error}"
- )
- raise RuntimeError(
- f"Both attempts to generate image failed: {fallback_error}"
- )
- else:
- bt.logging.error(f"Image generation error: {e}")
- raise RuntimeError(f"Failed to generate image: {e}")
-
- if self.model_name == "DeepFloyd/IF":
- max_retries = 3
- attempt = 0
- while attempt < max_retries:
- img = gen_output.images[0]
- if is_black_image(img):
- bt.logging.warning("DeepFloyd/IF returned a black image (likely NSFW). Attempting to sanitize prompt and retry.")
- # Ensure prompt generator models are loaded
- self.prompt_generator.load_llm()
- # Sanitize the prompt
- prompt = self.prompt_generator.sanitize_prompt(prompt)
- truncated_prompt = truncate_prompt_if_too_long(prompt, self.model)
- bt.logging.info(f"Sanitized prompt: {truncated_prompt}")
- self.prompt_generator.clear_gpu()
- try:
- gen_output = generate(truncated_prompt, **gen_args)
- except Exception as e:
- bt.logging.error(f"Sanitized prompt generation failed: {e}")
- break
- attempt += 1
- else:
- break
-
- print(f"Finished generation in {gen_time/60} minutes")
- return {
- 'prompt': truncated_prompt,
- 'prompt_long': prompt,
- 'gen_output': gen_output, # image or video
- 'time': time.time(),
- 'model_name': self.model_name,
- 'gen_time': gen_time,
- 'mask_image': gen_args.get('mask_image', None),
- 'mask_center': mask_center,
- 'image': gen_args.get('image', None)
- }
-
- def load_model(self, model_name: Optional[str] = None, modality: Optional[str] = None) -> None:
- """Load a Hugging Face text-to-image or text-to-video model."""
- if model_name is not None:
- self.model_name = model_name
- elif self.use_random_model or model_name == 'random':
- self.model_name = select_random_model(modality)
-
- bt.logging.info(f"Loading {self.model_name}")
-
- model_config = MODELS[self.model_name]
- pipeline_cls = model_config['pipeline_cls']
- pipeline_args = model_config.get('from_pretrained_args', {}).copy()
-
- # Handle custom loading functions passed as tuples
- for k, v in pipeline_args.items():
- if isinstance(v, tuple) and callable(v[0]):
- pipeline_args[k] = v[0](**v[1])
-
- # Get model_id if specified, otherwise use model_name
- model_id = pipeline_args.pop('model_id', self.model_name)
-
- # Handle multi-stage pipeline
- if isinstance(pipeline_cls, dict):
- self.model = {}
- for stage_name, stage_cls in pipeline_cls.items():
- stage_args = pipeline_args.get(stage_name, {})
- base_model = stage_args.get('base', model_id)
- stage_args_filtered = {k:v for k,v in stage_args.items() if k != 'base'}
-
- bt.logging.info(f"Loading {stage_name} from {base_model}")
- self.model[stage_name] = stage_cls.from_pretrained(
- base_model,
- cache_dir=HUGGINGFACE_CACHE_DIR,
- **stage_args_filtered,
- add_watermarker=False
- )
-
- enable_model_optimizations(
- model=self.model[stage_name],
- device=self.device,
- enable_cpu_offload=model_config.get('enable_model_cpu_offload', False),
- enable_sequential_cpu_offload=model_config.get('enable_sequential_cpu_offload', False),
- enable_vae_slicing=model_config.get('vae_enable_slicing', False),
- enable_vae_tiling=model_config.get('vae_enable_tiling', False),
- stage_name=stage_name
- )
-
- # Disable watermarker
- self.model[stage_name].watermarker = None
- else:
- # Single-stage pipeline
- self.model = pipeline_cls.from_pretrained(
- model_id,
- cache_dir=HUGGINGFACE_CACHE_DIR,
- **pipeline_args,
- add_watermarker=False
- )
-
- # Load LoRA weights if specified
- if 'lora_model_id' in model_config:
- bt.logging.info(f"Loading LoRA weights from {model_config['lora_model_id']}")
- lora_loading_args = model_config.get('lora_loading_args', {})
- self.model.load_lora_weights(
- model_config['lora_model_id'],
- **lora_loading_args
- )
-
- # Load scheduler if specified
- if 'scheduler' in model_config:
- sched_cls = model_config['scheduler']['cls']
- sched_args = model_config['scheduler'].get('from_config_args', {})
- self.model.scheduler = sched_cls.from_config(
- self.model.scheduler.config,
- **sched_args
- )
-
- enable_model_optimizations(
- model=self.model,
- device=self.device,
- enable_cpu_offload=model_config.get('enable_model_cpu_offload', False),
- enable_sequential_cpu_offload=model_config.get('enable_sequential_cpu_offload', False),
- enable_vae_slicing=model_config.get('vae_enable_slicing', False),
- enable_vae_tiling=model_config.get('vae_enable_tiling', False)
- )
-
- # Disable watermarker
- self.model.watermarker = None
-
- bt.logging.info(f"Loaded {self.model_name}")
-
- def clear_gpu(self) -> None:
- """Clear GPU memory by deleting models and running garbage collection."""
- if self.model is not None:
- bt.logging.info(
- "Deleting previous text-to-image or text-to-video model, "
- "freeing memory"
- )
- del self.model
- self.model = None
- gc.collect()
- torch.cuda.empty_cache()
\ No newline at end of file
diff --git a/bitmind/transforms.py b/bitmind/transforms.py
new file mode 100644
index 00000000..8aa6fdc3
--- /dev/null
+++ b/bitmind/transforms.py
@@ -0,0 +1,605 @@
+import math
+import random
+from scipy import ndimage
+import numpy as np
+import cv2
+
+
+def apply_random_augmentations(
+ inputs, target_image_size, mask_point=None, level_probs=None
+):
+ """
+ Apply image transformations based on randomly selected level.
+
+ Args:
+        inputs: image array, video array, or tuple of paired inputs to transform
+        target_image_size: target (height, width) used by the crop/resize transforms
+        mask_point: optional (x, y) point that random crops must preserve
+        level_probs: dict mapping augmentation level to its sampling probability.
+            Default probabilities (uniform):
+            - Level 0 (25%): No augmentations (base transforms)
+            - Level 1 (25%): Basic augmentations
+            - Level 2 (25%): Medium distortions
+            - Level 3 (25%): Hard distortions
+
+ Returns:
+ tuple: (transformed_image, level, transform_params)
+
+ Raises:
+ ValueError: If probabilities don't sum to 1.0 (within floating point precision)
+ """
+ if level_probs is None:
+        level_probs = {  # Difficulty level probabilities
+ 0: 0.25, # No augmentations (base transforms)
+ 1: 0.25, # Basic augmentations
+ 2: 0.25, # Medium distortions
+ 3: 0.25, # Hard distortions
+ }
+
+ if not math.isclose(sum(level_probs.values()), 1.0, rel_tol=1e-9):
+ raise ValueError("Probabilities of levels must sum to 1.0")
+
+ # get cumulative probs and select augmentation level
+ cumulative_probs = {}
+ cumsum = 0
+ for level, prob in sorted(level_probs.items()):
+ cumsum += prob
+ cumulative_probs[level] = cumsum
+
+ rand_val = np.random.random()
+ for curr_level, cum_prob in cumulative_probs.items():
+ if rand_val <= cum_prob:
+ level = curr_level
+ break
+
+ if level == 0:
+ tforms = get_base_transforms(target_image_size)
+ elif level == 1:
+ tforms = get_random_augmentations(target_image_size, mask_point)
+ elif level == 2:
+ tforms = get_random_augmentations_medium(target_image_size, mask_point)
+ else: # level == 3
+ tforms = get_random_augmentations_hard(target_image_size, mask_point)
+
+ if isinstance(inputs, tuple):
+        transformed_A = tforms(inputs[0])
+        transformed_B = tforms(inputs[1], clear_params=False)  # replay params so the pair is augmented identically
+ transformed = np.concatenate([transformed_A, transformed_B], axis=0)
+ else:
+ transformed = tforms(inputs)
+
+ return transformed, level, tforms.params
+
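+# A minimal usage sketch (illustrative only; the input array and target size below
+# are assumed, not taken from elsewhere in this repo):
+#
+#     frame = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)
+#     augmented, level, params = apply_random_augmentations(frame, (256, 256))
+#     # `augmented` is the transformed array, `level` the sampled difficulty,
+#     # and `params` the per-transform parameters (reusable for a paired input).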
+
+def get_base_transforms(target_image_size):
+ return ComposeWithParams(
+ [
+ CenterCrop(),
+ Resize(),
+ ]
+ )
+
+
+def get_random_augmentations(target_image_size, mask_point=None):
+ """Basic augmentations with geometric transforms"""
+ base_augmentations = [
+ RandomRotationWithParams(degrees=20, order=2),
+ RandomResizedCropWithParams(
+ target_image_size, scale=(0.2, 1.0), include_point=mask_point
+ ),
+ RandomHorizontalFlipWithParams(),
+ RandomVerticalFlipWithParams(),
+ ]
+ return ComposeWithParams(base_augmentations)
+
+
+def get_random_augmentations_medium(target_image_size, mask_point=None):
+ """Medium difficulty transforms with mild distortions"""
+ base_augmentations = get_random_augmentations(target_image_size, mask_point)
+
+ distortions = [
+ ApplyDeeperForensicsDistortion("CS", level_min=0, level_max=1),
+ ApplyDeeperForensicsDistortion("CC", level_min=0, level_max=1),
+ ApplyDeeperForensicsDistortion("JPEG", level_min=0, level_max=1),
+ ]
+
+ return ComposeWithParams(base_augmentations.transforms + distortions)
+
+
+def get_random_augmentations_hard(target_image_size, mask_point=None):
+ """Hard difficulty transforms with more severe distortions"""
+ base_augmentations = get_random_augmentations(target_image_size, mask_point)
+
+ distortions = [
+ ApplyDeeperForensicsDistortion("CS", level_min=0, level_max=2),
+ ApplyDeeperForensicsDistortion("CC", level_min=0, level_max=2),
+ ApplyDeeperForensicsDistortion("JPEG", level_min=0, level_max=2),
+ ApplyDeeperForensicsDistortion("GNC", level_min=0, level_max=2),
+ ApplyDeeperForensicsDistortion("GB", level_min=0, level_max=2),
+ ]
+
+ return ComposeWithParams(base_augmentations.transforms + distortions)
+
+
+class ComposeWithParams:
+ def __init__(self, transforms):
+ self.transforms = transforms
+ self.params = {}
+
+ def __call__(self, input_data, clear_params=True):
+ if clear_params:
+ self.params = {}
+
+ is_single_image = input_data.ndim == 3 # (H, W, C)
+ if is_single_image:
+ input_data = input_data[None, ...] # Add fake temporal dim → (1, H, W, C)
+
+ output_frames = []
+ for frame in input_data:
+ for transform in self.transforms:
+ name = getattr(transform, "__name__", transform.__class__.__name__)
+ if name in self.params:
+ frame = transform(frame, **self.params[name])
+ else:
+ frame = transform(frame)
+ if hasattr(transform, "params"):
+ self.params[name] = transform.params
+ output_frames.append(frame)
+
+ output = np.stack(output_frames)
+ return output[0] if is_single_image else output
+
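+# ComposeWithParams applies each transform frame by frame and caches the parameters
+# sampled on the first pass, so every frame of a video (and the second element of a
+# paired input) can receive identical crops, flips, and distortion levels. A hedged
+# sketch, assuming (T, H, W, C) uint8 arrays:
+#
+#     video_a = np.random.randint(0, 255, (8, 256, 256, 3), dtype=np.uint8)
+#     video_b = np.random.randint(0, 255, (8, 256, 256, 3), dtype=np.uint8)
+#     tforms = get_random_augmentations((256, 256))
+#     out_a = tforms(video_a)                      # samples and caches params
+#     out_b = tforms(video_b, clear_params=False)  # replays the cached params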
+
+class ApplyDeeperForensicsDistortion:
+ """Wrapper for applying DeeperForensics distortions."""
+
+ def __init__(self, distortion_type, level_min=0, level_max=3):
+ self.__name__ = distortion_type
+ self.distortion_type = distortion_type
+ self.level = None
+ self.level_min = level_min
+ self.level_max = level_max
+ self.params = {} # level
+ self.distortion_params = {} # distortion_type specific
+
+ def __call__(self, img, level=None):
+ if level is None and self.level is None:
+ self.level = random.randint(self.level_min, self.level_max)
+ self.params = {"level": self.level}
+ elif self.level is None:
+ self.level = level
+ self.params = {"level": self.level}
+
+ if self.level > 0:
+ self.distortion_func = get_distortion_function(self.distortion_type)
+ if len(self.distortion_params) == 0:
+ self.distortion_param = get_distortion_parameter(
+ self.distortion_type, self.level
+ )
+ self.distortion_params = {"param": self.distortion_param}
+ else:
+ return img
+
+ output = self.distortion_func(img, **self.distortion_params)
+ if isinstance(output, tuple):
+ self.distortion_params.update(output[1])
+ return output[0]
+ else:
+ return output
+
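+# A hedged usage sketch for the distortion wrapper (the input array below is assumed):
+#
+#     jpeg = ApplyDeeperForensicsDistortion("JPEG", level_min=1, level_max=3)
+#     frame = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)
+#     degraded = jpeg(frame)  # samples a level once and reuses it on later calls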
+
+# DeeperForensics Distortion Functions
+def get_distortion_parameter(distortion_type, level):
+ """Get distortion parameter based on type and level.
+
+ Parameters are arranged from least severe (level 1) to most severe (level 5).
+ Each distortion type has different parameter behavior:
+
+ CS (Color Saturation):
+ - Range: [0.4 -> 0.0]
+ - Lower values = worse distortion
+ - 0.4 = slight desaturation
+ - 0.0 = complete desaturation (grayscale)
+
+ CC (Color Contrast):
+ - Range: [0.85 -> 0.35]
+ - Lower values = worse distortion
+ - 0.85 = slight contrast reduction
+ - 0.35 = severe contrast reduction
+
+ BW (Block Wise):
+ - Range: [16 -> 80]
+ - Higher values = worse distortion
+ - Controls number of random blocks added
+ - 16 = few blocks
+ - 80 = many blocks
+
+ GNC (Gaussian Noise Color):
+ - Range: [0.001 -> 0.05]
+ - Higher values = worse distortion
+ - Controls noise variance
+ - 0.001 = subtle noise
+ - 0.05 = very noisy
+
+ GB (Gaussian Blur):
+ - Range: [7 -> 21]
+ - Higher values = worse distortion
+ - Controls blur kernel size
+ - 7 = slight blur
+ - 21 = heavy blur
+
+ JPEG (JPEG Compression):
+ - Range: [2 -> 6]
+ - Higher values = worse distortion
+ - Controls downsampling factor
+ - 2 = mild compression
+ - 6 = severe compression
+ """
+ param_dict = {
+ "CS": [0.4, 0.3, 0.2, 0.1, 0.0],
+ "CC": [0.85, 0.725, 0.6, 0.475, 0.35],
+ "BW": [16, 32, 48, 64, 80],
+ "GNC": [0.001, 0.002, 0.005, 0.01, 0.05],
+ "GB": [7, 9, 13, 17, 21],
+ "JPEG": [2, 3, 4, 5, 6],
+ }
+ return param_dict[distortion_type][level - 1]
+
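+# Example: the level indexes into the tables above (level 1 = mildest, level 5 = harshest):
+#
+#     get_distortion_parameter("GB", 1)   -> 7    (slight blur)
+#     get_distortion_parameter("GB", 5)   -> 21   (heavy blur)
+#     get_distortion_parameter("CS", 5)   -> 0.0  (full desaturation)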
+
+def get_distortion_function(distortion_type):
+ """Get distortion function based on type."""
+ func_dict = {
+ "CS": color_saturation,
+ "CC": color_contrast,
+ "BW": block_wise,
+ "GNC": gaussian_noise_color,
+ "GB": gaussian_blur,
+ "JPEG": jpeg_compression,
+ }
+ return func_dict[distortion_type]
+
+
+def Resize():
+    """Resize an image to a fixed 256x256 output."""
+
+    def resize(img):
+        return cv2.resize(img, (256, 256))
+
+    return resize
+
+
+def rgb2ycbcr(img_rgb):
+ """Convert RGB image to YCbCr color space.
+
+ Args:
+ img_rgb (np.ndarray): RGB image array of shape (H, W, 3)
+
+ Returns:
+ np.ndarray: YCbCr image array of shape (H, W, 3) with values normalized to [0,1]
+ """
+ img_rgb = img_rgb.astype(np.float32)
+ img_ycrcb = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2YCR_CB)
+ img_ycbcr = img_ycrcb[:, :, (0, 2, 1)].astype(np.float32)
+ img_ycbcr[:, :, 0] = (img_ycbcr[:, :, 0] * (235 - 16) + 16) / 255.0
+ img_ycbcr[:, :, 1:] = (img_ycbcr[:, :, 1:] * (240 - 16) + 16) / 255.0
+ return img_ycbcr
+
+
+def ycbcr2rgb(img_ycbcr):
+ """Convert YCbCr image to RGB color space.
+
+ Args:
+ img_ycbcr (np.ndarray): YCbCr image array of shape (H, W, 3)
+
+ Returns:
+ np.ndarray: RGB image array of shape (H, W, 3) with values in [0,255]
+ """
+ img_ycbcr = img_ycbcr.astype(np.float32)
+ img_ycbcr[:, :, 0] = (img_ycbcr[:, :, 0] * 255.0 - 16) / (235 - 16)
+ img_ycbcr[:, :, 1:] = (img_ycbcr[:, :, 1:] * 255.0 - 16) / (240 - 16)
+ img_ycrcb = img_ycbcr[:, :, (0, 2, 1)].astype(np.float32)
+ img_rgb = cv2.cvtColor(img_ycrcb, cv2.COLOR_YCR_CB2RGB)
+ return img_rgb
+
+
+def color_saturation(img, param):
+ """Apply color saturation distortion.
+
+ Args:
+ img (np.ndarray): Input RGB image array of shape (H, W, 3)
+ param (float): Saturation multiplier parameter
+
+ Returns:
+ np.ndarray: Distorted RGB image array with modified saturation
+ """
+ ycbcr = rgb2ycbcr(img)
+ ycbcr[:, :, 1] = 0.5 + (ycbcr[:, :, 1] - 0.5) * param
+ ycbcr[:, :, 2] = 0.5 + (ycbcr[:, :, 2] - 0.5) * param
+ img = ycbcr2rgb(ycbcr).astype(np.uint8)
+ return img
+
+
+def color_contrast(img, param):
+ """Apply color contrast distortion.
+
+ Args:
+ img (np.ndarray): Input RGB image array of shape (H, W, 3)
+ param (float): Contrast multiplier parameter
+
+ Returns:
+ np.ndarray: Distorted RGB image array with modified contrast
+ """
+ img = img.astype(np.float32) * param
+ return img.astype(np.uint8)
+
+
+def block_wise(img, param):
+ """Apply block-wise distortion by adding random gray blocks.
+
+ NOTE: CURRENTLY NOT USED
+
+ Args:
+ img (np.ndarray): Input RGB image array of shape (H, W, 3)
+ param (int): Number of blocks to add, scaled by image dimensions
+
+ Returns:
+ np.ndarray: Distorted RGB image array with added gray blocks
+ """
+ width = 8
+ block = np.ones((width, width, 3)).astype(int) * 128
+ param = min(img.shape[0], img.shape[1]) // 256 * param
+ for _ in range(param):
+ r_w = random.randint(0, img.shape[1] - 1 - width)
+ r_h = random.randint(0, img.shape[0] - 1 - width)
+ img[r_h : r_h + width, r_w : r_w + width, :] = block
+ return img
+
+
+def gaussian_noise_color(img, param, b=None):
+ """Apply colored Gaussian noise in YCbCr color space.
+
+ Args:
+ img (np.ndarray): Input RGB image array of shape (H, W, 3)
+        param (float): Variance of the Gaussian noise
+        b (np.ndarray, optional): Pre-computed noise image; pass the "b" value
+            returned by a previous call to reapply identical noise
+
+    Returns:
+        tuple: (distorted RGB uint8 array, {"b": noise image}) so the same noise
+            can be replayed across frames
+ """
+ ycbcr = rgb2ycbcr(img) / 255
+ size_a = ycbcr.shape
+ if b is None:
+ b = (
+ ycbcr + math.sqrt(param) * np.random.randn(size_a[0], size_a[1], size_a[2])
+ ) * 255
+ b = ycbcr2rgb(b)
+ return np.clip(b, 0, 255).astype(np.uint8), {"b": b}
+
+
+def gaussian_blur(img, param):
+ """Apply Gaussian blur with specified kernel size.
+
+    Args:
+ img (np.ndarray): Input RGB image array of shape (H, W, 3)
+ param (int): Gaussian kernel size (must be odd)
+
+ Returns:
+ np.ndarray: Blurred RGB image array
+ """
+ return cv2.GaussianBlur(img, (param, param), param * 1.0 / 6)
+
+
+def jpeg_compression(img, param):
+ """Apply JPEG compression-like distortion through downsampling.
+
+ Args:
+ img (np.ndarray): Input RGB image array of shape (H, W, 3)
+ param (int): Downsampling factor
+
+ Returns:
+ np.ndarray: Distorted RGB image array with compression artifacts
+ """
+ h, w, _ = img.shape
+ s_h = h // param
+ s_w = w // param
+ img = cv2.resize(img, (s_w, s_h))
+ return cv2.resize(img, (w, h))
+
+
+def CenterCrop():
+ """Center crop an image to a square.
+
+ Args:
+ img (np.ndarray): Input RGB image array of shape (H, W, 3)
+
+ Returns:
+ np.ndarray: Center cropped RGB image array with equal height and width
+ """
+
+ def crop(img):
+ h, w = img.shape[:2]
+ m = min(h, w)
+ i = (h - m) // 2
+ j = (w - m) // 2
+ return img[i : i + m, j : j + m]
+
+ return crop
+
+
+class RandomResizedCropWithParams:
+ """Randomly crop and resize an image while optionally preserving a point.
+
+ Args:
+        size (int or tuple): Target output size
+        scale (tuple): Range of the fraction of the original image area to crop
+        include_point (tuple, optional): (x, y) point that must remain inside the crop
+ """
+
+ def __init__(self, size, scale, include_point=None):
+ self.params = None
+ self.scale = scale
+ self.size = size
+ self.include_point = include_point
+
+ def __call__(self, img, crop_params=None):
+ """Perform random resized crop transform.
+
+ Args:
+ img (np.ndarray): Input RGB image array of shape (H, W, C)
+ crop_params (tuple, optional): Pre-computed crop parameters (i, j, h, w)
+ where image will be cropped to [i:i+h, j:j+w] before resizing
+
+ Returns:
+ np.ndarray: Randomly cropped and resized RGB image array
+ """
+ # Convert numpy array to shape expected by parent class
+ height, width = img.shape[:2]
+
+ if crop_params is None:
+ area = height * width
+ target_area = area * np.random.uniform(*self.scale)
+ h = w = int(round(np.sqrt(target_area)))
+ h = min(h, height)
+ w = min(w, width)
+ i = np.random.randint(0, height - h + 1)
+ j = np.random.randint(0, width - w + 1)
+ if self.include_point is not None:
+ x, y = self.include_point
+
+ # adjust crop to keep mask point
+ if x < j:
+ j = max(0, x - 10)
+ elif x > j + w:
+ j = min(width - w, x - w + 10)
+
+ if y < i:
+ i = max(0, y - 10)
+ elif y > i + h:
+ i = min(height - h, y - h + 10)
+ else:
+ i, j, h, w = crop_params
+
+ self.params = {"crop_params": (i, j, h, w)}
+ cropped = img[i : i + h, j : j + w, :]
+
+ if isinstance(self.size, int):
+ size = (self.size, self.size)
+ else:
+ size = self.size
+ resized = cv2.resize(cropped, size, interpolation=cv2.INTER_LINEAR)
+ return resized
+
+
+class RandomHorizontalFlipWithParams:
+ """Randomly flip an image horizontally.
+
+ Args:
+ p (float): Probability of flipping the image
+ """
+
+ def __init__(self, p=0.5):
+ self.p = p
+ self.params = {}
+
+ def __call__(self, img, flip=None):
+ """Perform horizontal flip transform.
+
+ Args:
+ img (np.ndarray): Input RGB image array of shape (H, W, C)
+ flip (bool, optional): Pre-computed flip decision
+
+ Returns:
+ np.ndarray: Horizontally flipped RGB image array if flip is True
+ """
+ if flip is not None:
+ self.params = {"flip": flip}
+ return np.fliplr(img) if flip else img
+ elif not hasattr(self, "params") or len(self.params) == 0:
+ flip = np.random.random() < self.p
+ self.params = {"flip": flip}
+ return np.fliplr(img) if flip else img
+ else:
+ return np.fliplr(img) if self.params.get("flip", False) else img
+
+
+class RandomVerticalFlipWithParams:
+ """Randomly flip an image vertically.
+
+ Args:
+ p (float): Probability of flipping the image
+ """
+
+ def __init__(self, p=0.5):
+ self.p = p
+ self.params = {}
+
+ def __call__(self, img, flip=None):
+ """Perform vertical flip transform.
+
+ Args:
+ img (np.ndarray): Input RGB image array of shape (H, W, C)
+ flip (bool, optional): Pre-computed flip decision
+
+ Returns:
+ np.ndarray: Vertically flipped RGB image array if flip is True
+ """
+ if flip is not None:
+ self.params = {"flip": flip}
+ return np.flipud(img) if flip else img
+ elif not hasattr(self, "params") or len(self.params) == 0:
+ flip = np.random.random() < self.p
+ self.params = {"flip": flip}
+ return np.flipud(img) if flip else img
+ else:
+ return np.flipud(img) if self.params.get("flip", False) else img
+
+
+class RandomRotationWithParams:
+ """Randomly rotate an image.
+
+ Args:
+ degrees (float or tuple): Range of degrees to select from. If float, uses (-degrees, degrees)
+ p (float): Probability of rotating the image
+ reshape (bool): If True, expands output image to fit rotated image
+ mode (str): How to fill the border ('reflect', 'constant', etc)
+ order (int): Interpolation order (0-5)
+ """
+
+ def __init__(self, degrees, p=0.5, reshape=False, mode="reflect", order=2):
+ if isinstance(degrees, (tuple, list)):
+ self.degrees = degrees
+ else:
+ self.degrees = (-degrees, degrees)
+ self.p = p
+ self.params = None
+ self.reshape = reshape
+ self.mode = mode
+ self.order = order
+
+ def __call__(self, img, rotate=None, angle=None, order=None):
+ """Perform rotation transform.
+
+ Args:
+ img (np.ndarray): Input RGB image array of shape (H, W, C)
+ rotate (bool, optional): Pre-computed rotation decision
+ angle (float, optional): Pre-computed rotation angle
+ order (int, optional): Pre-computed interpolation order
+
+ Returns:
+ np.ndarray: Rotated RGB image array
+ """
+        if rotate is None:
+            rotate = np.random.random() < self.p
+        self.params = {"rotate": rotate}
+
+ if not rotate:
+ return img
+
+ order = self.order if order is None else order
+ if isinstance(order, (tuple, list)):
+ order = random.randint(order[0], order[1])
+
+ if angle is None:
+ angle = random.uniform(self.degrees[0], self.degrees[1])
+
+ self.params.update({"order": order, "angle": angle})
+ return ndimage.rotate(
+ img, angle, reshape=self.reshape, mode=self.mode, order=order, axes=(0, 1)
+ )
diff --git a/bitmind/types.py b/bitmind/types.py
new file mode 100644
index 00000000..091d547d
--- /dev/null
+++ b/bitmind/types.py
@@ -0,0 +1,191 @@
+from pathlib import Path
+from dataclasses import dataclass, field
+from enum import Enum, auto
+from pydantic import BaseModel
+from typing import Dict, List, Any, Optional, Union
+
+
+class NeuronType(Enum):
+ VALIDATOR = "VALIDATOR"
+ VALIDATOR_PROXY = "VALIDATOR_PROXY"
+ MINER = "MINER"
+
+
+class FileType(Enum):
+ PARQUET = auto()
+ ZIP = auto()
+ VIDEO = auto()
+ IMAGE = auto()
+
+
+class CacheType(str, Enum):
+ MEDIA = "media"
+ COMPRESSED = "compressed"
+
+
+class Modality(str, Enum):
+ IMAGE = "image"
+ VIDEO = "video"
+
+
+class MediaType(str, Enum):
+ REAL = "real", 0
+ SYNTHETIC = "synthetic", 1
+ SEMISYNTHETIC = "semisynthetic", 2
+
+ def __new__(cls, str_value, int_value):
+ obj = str.__new__(cls, str_value)
+ obj._value_ = str_value
+ obj.int_value = int_value
+ return obj
+
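+# MediaType doubles as a string enum and an integer class label, e.g.:
+#
+#     MediaType("synthetic") is MediaType.SYNTHETIC   # lookup by string value
+#     MediaType.SYNTHETIC.value      -> "synthetic"
+#     MediaType.SYNTHETIC.int_value  -> 1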
+
+@dataclass
+class CacheUpdaterConfig:
+ num_sources_per_dataset: int = 1
+ num_items_per_source: int = 100
+
+
+@dataclass
+class CacheConfig:
+ """Configuration for a cache at base_dir / {modality} / {media_type}"""
+
+ modality: str
+ media_type: str
+ base_dir: Path = Path("~/.cache/sn34").expanduser()
+ tags: Optional[List[str]] = None
+ max_compressed_gb: float = 100.0
+ max_media_gb: float = 10.0
+
+ def get_path(self):
+ media_cache_path = Path(self.base_dir) / self.modality / self.media_type
+ media_cache_path.mkdir(exist_ok=True, parents=True)
+ return media_cache_path
+
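+# Example layout produced by get_path() (the directory is created on first call,
+# with ~ expanded):
+#
+#     CacheConfig(modality="image", media_type="real").get_path()
+#     -> ~/.cache/sn34/image/real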
+
+@dataclass
+class DatasetConfig:
+ path: str # HuggingFace path
+ type: Modality
+ media_type: MediaType
+ tags: List[str] = field(default_factory=list)
+ file_format: str = ""
+ compressed_format: str = ""
+ priority: int = 1 # Optional: priority for sampling (higher is more frequent)
+ enabled: bool = True
+
+ def __post_init__(self):
+ """Validate and set defaults"""
+ if not self.compressed_format:
+ if self.type == Modality.IMAGE:
+ self.compressed_format = "parquet"
+ elif self.type == Modality.VIDEO:
+ self.compressed_format = "zip"
+
+ if isinstance(self.tags, str):
+ self.tags = [t.strip() for t in self.tags.split(",")]
+
+ if isinstance(self.type, str):
+ self.type = Modality(self.type.lower())
+
+ if isinstance(self.media_type, str):
+ self.media_type = MediaType(self.media_type.lower())
+
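+# DatasetConfig accepts plain strings and normalizes them in __post_init__.
+# A minimal sketch (the dataset path below is hypothetical):
+#
+#     ds = DatasetConfig(path="org/some-dataset", type="image",
+#                        media_type="real", tags="faces, portrait")
+#     ds.type == Modality.IMAGE, ds.media_type == MediaType.REAL
+#     ds.tags == ["faces", "portrait"], ds.compressed_format == "parquet"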
+
+class ModelTask(str, Enum):
+ """Type of task the model is designed for"""
+
+ TEXT_TO_IMAGE = "t2i"
+ TEXT_TO_VIDEO = "t2v"
+ IMAGE_TO_IMAGE = "i2i"
+ IMAGE_TO_VIDEO = "i2v"
+
+
+class ModelConfig:
+ """
+ Configuration for a generative AI model.
+
+ Attributes:
+ path: The Hugging Face model path or identifier
+        task: The primary task of the model (T2I, T2V, I2I, or I2V)
+ media_type: Type of output (synthetic or semisynthetic)
+ pipeline_cls: Pipeline class used to load the model
+ pretrained_args: Arguments for the from_pretrained method
+ generate_args: Default arguments for generation
+ tags: List of tags for categorizing the model
+ use_autocast: Whether to use autocast during generation
+        scheduler: Optional scheduler configuration (scheduler class and from_config args)
+        lora_model_id: Optional LoRA weights to load on top of the base pipeline
+        lora_loading_args: Optional arguments for load_lora_weights
+ """
+
+ def __init__(
+ self,
+ path: str,
+ task: ModelTask,
+ pipeline_cls: Union[Any, Dict[str, Any]],
+ media_type: Optional[MediaType] = None,
+ pretrained_args: Dict[str, Any] = None,
+ generate_args: Dict[str, Any] = None,
+ tags: List[str] = None,
+ use_autocast: bool = True,
+ enable_model_cpu_offload: bool = False,
+ enable_sequential_cpu_offload: bool = False,
+ vae_enable_slicing: bool = False,
+ vae_enable_tiling: bool = False,
+ scheduler: Dict[str, Any] = None,
+ save_args: Dict[str, Any] = None,
+ pipeline_stages: List[Dict[str, Any]] = None,
+ clear_memory_on_stage_end: bool = False,
+ lora_model_id: str = None,
+ lora_loading_args: Dict[str, Any] = None,
+ ):
+ self.path = path
+ self.task = task
+ self.pipeline_cls = pipeline_cls
+ self.media_type = media_type
+
+ if self.media_type is None:
+ self.media_type = (
+ MediaType.SEMISYNTHETIC
+ if task == ModelTask.IMAGE_TO_IMAGE
+ else MediaType.SYNTHETIC
+ )
+
+ self.pretrained_args = pretrained_args or {}
+ self.generate_args = generate_args or {}
+ self.tags = tags or []
+ self.use_autocast = use_autocast
+ self.enable_model_cpu_offload = enable_model_cpu_offload
+ self.enable_sequential_cpu_offload = enable_sequential_cpu_offload
+ self.vae_enable_slicing = vae_enable_slicing
+ self.vae_enable_tiling = vae_enable_tiling
+ self.scheduler = scheduler
+ self.save_args = save_args or {}
+ self.pipeline_stages = pipeline_stages
+ self.clear_memory_on_stage_end = clear_memory_on_stage_end
+ self.lora_model_id = lora_model_id
+ self.lora_loading_args = lora_loading_args
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert config to dictionary format"""
+ return {
+ "pipeline_cls": self.pipeline_cls,
+ "from_pretrained_args": self.pretrained_args,
+ "generate_args": self.generate_args,
+ "use_autocast": self.use_autocast,
+ "enable_model_cpu_offload": self.enable_model_cpu_offload,
+ "enable_sequential_cpu_offload": self.enable_sequential_cpu_offload,
+ "vae_enable_slicing": self.vae_enable_slicing,
+ "vae_enable_tiling": self.vae_enable_tiling,
+ "scheduler": self.scheduler,
+ "save_args": self.save_args,
+ "pipeline_stages": self.pipeline_stages,
+ "clear_memory_on_stage_end": self.clear_memory_on_stage_end,
+ }
+
+
+class ValidatorConfig(BaseModel):
+ skip_weight_set: Optional[bool] = False
+ set_weights_on_start: Optional[bool] = False
+ max_concurrent_organics: Optional[int] = 2
diff --git a/bitmind/utils.py b/bitmind/utils.py
new file mode 100644
index 00000000..b2e35939
--- /dev/null
+++ b/bitmind/utils.py
@@ -0,0 +1,108 @@
+import traceback
+import bittensor as bt
+import functools
+import json
+import os
+
+
+def print_info(metagraph, hotkey, block, isMiner=True):
+ uid = metagraph.hotkeys.index(hotkey)
+ log = f"UID:{uid} | Block:{block} | Consensus:{metagraph.C[uid]} | "
+ if isMiner:
+ bt.logging.info(
+ log
+ + f"Stake:{metagraph.S[uid]} | Trust:{metagraph.T[uid]} | Incentive:{metagraph.I[uid]} | Emission:{metagraph.E[uid]}"
+ )
+ return
+    bt.logging.info(log + f"VTrust:{metagraph.Tv[uid]}")
+
+
+def fail_with_none(message: str = ""):
+ def outer(func):
+ def inner(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except Exception as e:
+ bt.logging.error(message)
+ bt.logging.error(str(e))
+ bt.logging.error(traceback.format_exc())
+ return None
+
+ return inner
+
+ return outer
+
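+# Usage sketch for fail_with_none (the decorated function is hypothetical):
+#
+#     @fail_with_none("failed to parse metadata")
+#     def parse_metadata(raw):
+#         return json.loads(raw)
+#
+#     parse_metadata("{bad json")  # logs the error and returns None instead of raising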
+
+def on_block_interval(interval_attr_name):
+ """
+ Decorator for methods that should only execute at specific block intervals.
+
+ Args:
+ interval_attr_name: String name of the config attribute that specifies the interval
+ """
+
+ def decorator(func):
+ @functools.wraps(func)
+ async def wrapper(self, block, *args, **kwargs):
+            interval = getattr(self.config, interval_attr_name, None)
+            if interval is None:
+                bt.logging.error(f"No interval found for {interval_attr_name}")
+                return None
+ if (
+ block == 0 or block % interval == 0
+ ): # Allow execution on block 0 for initialization
+ return await func(self, block, *args, **kwargs)
+ return None
+
+ return wrapper
+
+ return decorator
+
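+# Usage sketch for on_block_interval; the class and attribute names are hypothetical
+# and assume self.config carries the interval setting:
+#
+#     class Validator:
+#         @on_block_interval("epoch_length")
+#         async def set_weights(self, block):
+#             ...  # runs only when block == 0 or block % self.config.epoch_length == 0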
+
+class ExitContext:
+ """
+    Shared exit flag that can be passed to other threads: signal handlers call
+    startExit() and worker loops poll the instance's truthiness. A second exit
+    signal terminates the process immediately.
+ """
+
+ isExiting: bool = False
+
+ def startExit(self, *_):
+ if self.isExiting:
+ exit()
+ self.isExiting = True
+
+ def __bool__(self):
+ return self.isExiting
+
+
+def get_metadata(media_path):
+ """Get metadata for a media file if it exists."""
+ base_path = os.path.splitext(media_path)[0]
+ json_path = f"{base_path}.json"
+
+ if os.path.exists(json_path):
+ try:
+ with open(json_path, "r") as f:
+ return json.load(f)
+ except json.JSONDecodeError:
+            bt.logging.warning(f"Could not parse JSON file: {json_path}")
+ return {}
+ return {}
+
+
+def get_file_modality(filepath: str) -> str:
+ """
+ Determine the type of media file based on its extension.
+
+ Args:
+ filepath: Path to the media file
+
+ Returns:
+ "image", "video", or "file" based on the file extension
+ """
+ ext = os.path.splitext(filepath)[1].lower()
+ if ext in [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"]:
+ return "image"
+ elif ext in [".mp4", ".avi", ".mov", ".webm", ".mkv", ".flv"]:
+ return "video"
+ else:
+ return "file"
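+# Examples:
+#
+#     get_file_modality("clip.mp4")   -> "video"
+#     get_file_modality("photo.png")  -> "image"
+#     get_file_modality("notes.txt")  -> "file"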
diff --git a/bitmind/utils/config.py b/bitmind/utils/config.py
deleted file mode 100644
index ab9c1da4..00000000
--- a/bitmind/utils/config.py
+++ /dev/null
@@ -1,344 +0,0 @@
-# The MIT License (MIT)
-# Copyright © 2023 Yuma Rao
-# Copyright © 2023 Opentensor Foundation
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
-
-# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-import os
-import subprocess
-import argparse
-import bittensor as bt
-from bitmind.utils.logging import setup_events_logger
-
-
-def get_device():
- try:
- output = subprocess.check_output(["nvidia-smi", "-L"], stderr=subprocess.STDOUT)
- if "NVIDIA" in output.decode("utf-8"):
- return "cuda"
- except Exception:
- pass
- try:
- output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
- if "release" in output:
- return "cuda"
- except Exception:
- pass
- return "cpu"
-
-
-def replace_empty_with_default(args: argparse.Namespace, parser: argparse.ArgumentParser):
- for action in parser._actions:
- arg_name = action.dest
- if isinstance(getattr(args, arg_name), str) and getattr(args, arg_name) == "":
- setattr(args, arg_name, action.default)
- return args
-
-
-def check_config(cls, config: "bt.Config"):
- r"""Checks/validates the config namespace object."""
- bt.logging.check_config(config)
-
- full_path = os.path.expanduser(
- "{}/{}/{}/netuid{}/{}".format(
- config.logging.logging_dir, # TODO: change from ~/.bittensor/miners to ~/.bittensor/neurons
- config.wallet.name,
- config.wallet.hotkey,
- config.netuid,
- config.neuron.name,
- )
- )
- print("full path:", full_path)
- config.neuron.full_path = os.path.expanduser(full_path)
- if not os.path.exists(config.neuron.full_path):
- os.makedirs(config.neuron.full_path, exist_ok=True)
-
- config = replace_empty_with_default(config, add_all_args(cls))
- config.logging.info = True
-
- if not config.neuron.dont_save_events:
- # Add custom event logger for the events.
- events_logger = setup_events_logger(
- config.neuron.full_path, config.neuron.events_retention_size
- )
- bt.logging.register_primary_logger(events_logger.name)
- if config.logging.debug:
- bt.logging.enable_debug()
- elif config.logging.trace:
- bt.logging.enable_trace()
- elif config.logging.info:
- bt.logging.enable_info()
- else:
- bt.logging.enable_default()
-
-
-def add_args(cls, parser):
- """
- Adds relevant arguments to the parser for operation.
- """
-
- parser.add_argument("--netuid", type=int, help="Subnet netuid", default=1)
-
- parser.add_argument(
- "--neuron.epoch_length",
- type=int,
- help="The default epoch length (how often we set weights, measured in 12 second blocks).",
- default=100,
- )
-
- parser.add_argument(
- "--mock",
- action="store_true",
- help="Mock neuron and all network components.",
- default=False,
- )
-
- parser.add_argument(
- "--neuron.events_retention_size",
- type=str,
- help="Events retention size.",
- default=2 * 1024 * 1024 * 1024, # 2 GB
- )
-
- parser.add_argument(
- "--neuron.dont_save_events",
- action="store_true",
- help="If set, we dont save events to a log file.",
- default=False,
- )
-
- parser.add_argument(
- "--wandb.off",
- action="store_true",
- help="Turn off wandb.",
- default=False,
- )
-
- parser.add_argument(
- "--wandb.offline",
- action="store_true",
- help="Runs wandb in offline mode.",
- default=False,
- )
-
- parser.add_argument(
- "--wandb.notes",
- type=str,
- help="Notes to add to the wandb run.",
- default="",
- )
-
-
-def add_miner_args(cls, parser):
- """Add miner specific arguments to the parser."""
-
- parser.add_argument(
- "--neuron.image_detector_config",
- type=str,
- help=".yaml file name in base_miner/deepfake_detectors/configs/ to load for trained model.",
- default="camo.yaml",
- )
-
- parser.add_argument(
- "--neuron.image_detector",
- type=str,
- help="The DETECTOR_REGISTRY module name of the DeepfakeDetector subclass to use for inference.",
- default="CAMO",
- )
-
- parser.add_argument(
- "--neuron.image_detector_device",
- type=str,
- help="Device to run image detection model on.",
- default=get_device(),
- )
-
- parser.add_argument(
- "--neuron.video_detector_config",
- type=str,
- help=".yaml file name in base_miner/deepfake_detectors/configs/ to load for trained model.",
- default="tall.yaml",
- )
-
- parser.add_argument(
- "--neuron.video_detector",
- type=str,
- help="The DETECTOR_REGISTRY module name of the DeepfakeDetector subclass to use for inference.",
- default="TALL",
- )
-
- parser.add_argument(
- "--neuron.video_detector_device",
- type=str,
- help="Device to run image detection model on.",
- default=get_device(),
- )
-
- parser.add_argument(
- "--neuron.name",
- type=str,
- help="Trials for this neuron go in neuron.root / (wallet_cold - wallet_hot) / neuron.name. ",
- default="miner",
- )
-
- parser.add_argument(
- "--blacklist.force_validator_permit",
- action="store_true",
- help="If set, we will force incoming requests to have a permit.",
- default=False,
- )
-
- parser.add_argument(
- "--blacklist.allow_non_registered",
- action="store_true",
- help="If set, miners will accept queries from non registered entities. (Dangerous!)",
- default=False,
- )
-
- parser.add_argument(
- "--wandb.project_name",
- type=str,
- default="template-miners",
- help="Wandb project to log to.",
- )
-
- parser.add_argument(
- "--wandb.entity",
- type=str,
- default="opentensor-dev",
- help="Wandb entity to log to.",
- )
-
-
-def add_validator_args(cls, parser):
- """Add validator specific arguments to the parser."""
-
- parser.add_argument(
- "--neuron.device",
- type=str,
- help="Device to run on.",
- default=get_device(),
- )
-
- parser.add_argument(
- "--neuron.prompt_type",
- type=str,
- help="Choose 'annotation' to generate prompts from BLIP-2 annotations of real images, or 'random' for arbitrary prompts.",
- default='annotation',
- )
-
- parser.add_argument(
- "--neuron.name",
- type=str,
- help="Trials for this neuron go in neuron.root / (wallet_cold - wallet_hot) / neuron.name. ",
- default="validator",
- )
-
- parser.add_argument(
- "--neuron.timeout",
- type=float,
- help="The timeout for each forward call in seconds.",
- default=10,
- )
-
- parser.add_argument(
- "--neuron.num_concurrent_forwards",
- type=int,
- help="The number of concurrent forwards running at any time.",
- default=1,
- )
-
- parser.add_argument(
- "--neuron.sample_size",
- type=int,
- help="The number of miners to query in a single step.",
- default=50,
- )
-
- parser.add_argument(
- "--neuron.disable_set_weights",
- action="store_true",
- help="Disables setting weights.",
- default=False,
- )
-
- parser.add_argument(
- "--neuron.moving_average_alpha",
- type=float,
- help="Moving average alpha parameter, how much to add of the new observation.",
- default=0.05,
- )
-
- parser.add_argument(
- "--neuron.axon_off",
- "--axon_off",
- action="store_true",
- # Note: the validator needs to serve an Axon with their IP or they may
- # be blacklisted by the firewall of serving peers on the network.
- help="Set this flag to not attempt to serve an Axon.",
- default=False,
- )
-
- parser.add_argument(
- "--neuron.vpermit_tao_limit",
- type=int,
- help="The maximum number of TAO allowed to query a validator with a vpermit.",
- default=4096,
- )
-
- parser.add_argument(
- "--wandb.project_name",
- type=str,
- help="The name of the project where you are sending the new run.",
- default="template-validators",
- )
-
- parser.add_argument(
- "--wandb.entity",
- type=str,
- help="The name of the project where you are sending the new run.",
- default="opentensor-dev",
- )
-
- parser.add_argument(
- "--proxy.port",
- type=int,
- help="The port to run the proxy on.",
- default=10913
- )
-
- parser.add_argument(
- "--proxy.proxy_client_url",
- type=str,
- help="The url initialize credentials for proxy.",
- default="https://subnet-api.bitmindlabs.ai"
- )
-
-
-def add_all_args(cls):
- parser = argparse.ArgumentParser()
- bt.wallet.add_args(parser)
- bt.subtensor.add_args(parser)
- bt.logging.add_args(parser)
- bt.axon.add_args(parser)
- cls.add_args(parser)
- return parser
-
-
-def config(cls):
- """
- Returns the configuration object specific to this miner or validator after adding relevant arguments.
- """
- parser = add_all_args(cls)
- return bt.config(parser)
diff --git a/bitmind/utils/image_transforms.py b/bitmind/utils/image_transforms.py
deleted file mode 100644
index c58a9d6f..00000000
--- a/bitmind/utils/image_transforms.py
+++ /dev/null
@@ -1,518 +0,0 @@
-import math
-import random
-from PIL import Image
-import torchvision.transforms as transforms
-import torchvision.transforms.functional as F
-import numpy as np
-import torch
-import cv2
-
-from bitmind.validator.config import TARGET_IMAGE_SIZE
-
-
-def center_crop():
- def crop(img):
- m = min(img.size)
- return transforms.CenterCrop(m)(img)
- return crop
-
-
-class RandomResizedCropWithParams(transforms.RandomResizedCrop):
- def __init__(self, *args, include_point=None, **kwargs):
- super().__init__(*args, **kwargs)
- self.params = None
- self.include_point = include_point
- print(f"created RRC with point included: {self.include_point}")
-
- def forward(self, img, crop_params=None):
- """
- Args:
- img: PIL Image to be cropped and resized
- crop_params: Optional pre-computed crop parameters (i, j, h, w)
- """
- if crop_params is None:
- i, j, h, w = super().get_params(img, self.scale, self.ratio)
- if self.include_point is not None:
- x, y = self.include_point
- width, height = img.shape[1:]
-
- # adjust crop to keep mask point
- if x < j:
- j = max(0, x - 10)
- elif x > j + w:
- j = min(width - w, x - w + 10)
-
- if y < i:
- i = max(0, y - 10)
- elif y > i + h:
- i = min(height - h, y - h + 10)
- else:
- i, j, h, w = crop_params
-
- self.params = {'crop_params': (i, j, h, w)}
- return F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias)
-
-
-class RandomHorizontalFlipWithParams(transforms.RandomHorizontalFlip):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.params = {}
-
- def forward(self, img, do_flip=None):
- if do_flip is not None:
- self.params = {'do_flip': do_flip}
- return transforms.functional.hflip(img) if do_flip else img
- elif not hasattr(self, 'params'):
- do_flip = torch.rand(1) < self.p
- self.params = {'do_flip': do_flip}
- return transforms.functional.hflip(img) if do_flip else img
- else:
- return transforms.functional.hflip(img) if self.params.get('do_flip', False) else img
-
-
-class RandomVerticalFlipWithParams(transforms.RandomVerticalFlip):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.params = {}
-
- def forward(self, img, do_flip=None):
- if do_flip is not None:
- self.params = {'do_flip': do_flip}
- return transforms.functional.vflip(img) if do_flip else img
- elif not hasattr(self, 'params'):
- do_flip = torch.rand(1) < self.p
- self.params = {'do_flip': do_flip}
- return transforms.functional.vflip(img) if do_flip else img
- else:
- return transforms.functional.vflip(img) if self.params.get('do_flip', False) else img
-
-
-class RandomRotationWithParams(transforms.RandomRotation):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.params = None
-
- def forward(self, img, angle=None):
- if angle is None:
- angle = self.get_params(self.degrees)
- self.params = {'angle': angle}
- return transforms.functional.rotate(img, angle)
-
-
-class ConvertToRGB:
- def __call__(self, img):
- img = img.convert('RGB')
- return img
-
-
-# DeeperForensics Distortion Functions
-def get_distortion_parameter(distortion_type, level):
- """Get distortion parameter based on type and level.
-
- Parameters are arranged from least severe (level 1) to most severe (level 5).
- Each distortion type has different parameter behavior:
-
- CS (Color Saturation):
- - Range: [0.4 -> 0.0]
- - Lower values = worse distortion
- - 0.4 = slight desaturation
- - 0.0 = complete desaturation (grayscale)
-
- CC (Color Contrast):
- - Range: [0.85 -> 0.35]
- - Lower values = worse distortion
- - 0.85 = slight contrast reduction
- - 0.35 = severe contrast reduction
-
- BW (Block Wise):
- - Range: [16 -> 80]
- - Higher values = worse distortion
- - Controls number of random blocks added
- - 16 = few blocks
- - 80 = many blocks
-
- GNC (Gaussian Noise Color):
- - Range: [0.001 -> 0.05]
- - Higher values = worse distortion
- - Controls noise variance
- - 0.001 = subtle noise
- - 0.05 = very noisy
-
- GB (Gaussian Blur):
- - Range: [7 -> 21]
- - Higher values = worse distortion
- - Controls blur kernel size
- - 7 = slight blur
- - 21 = heavy blur
-
- JPEG (JPEG Compression):
- - Range: [2 -> 6]
- - Higher values = worse distortion
- - Controls downsampling factor
- - 2 = mild compression
- - 6 = severe compression
- """
- param_dict = {
- 'CS': [0.4, 0.3, 0.2, 0.1, 0.0],
- 'CC': [0.85, 0.725, 0.6, 0.475, 0.35],
- 'BW': [16, 32, 48, 64, 80],
- 'GNC': [0.001, 0.002, 0.005, 0.01, 0.05],
- 'GB': [7, 9, 13, 17, 21],
- 'JPEG': [2, 3, 4, 5, 6]
- }
- return param_dict[distortion_type][level - 1]
-
-
-def get_distortion_function(distortion_type):
- """Get distortion function based on type."""
- func_dict = {
- 'CS': color_saturation,
- 'CC': color_contrast,
- 'BW': block_wise,
- 'GNC': gaussian_noise_color,
- 'GB': gaussian_blur,
- 'JPEG': jpeg_compression
- }
- return func_dict[distortion_type]
-
-
-def rgb_to_bgr(tensor_img):
- """Convert a PyTorch tensor image from RGB to BGR format.
-
- Args:
- tensor_img: Tensor in format (C, H, W)
- """
- if tensor_img.shape[0] == 3:
- tensor_img = tensor_img[[2, 1, 0], ...]
- return tensor_img
-
-
-def bgr_to_rgb(tensor_img):
- """Convert a PyTorch tensor image from BGR to RGB format.
-
- Args:
- tensor_img: Tensor in format (C, H, W) with values in [0, 1]
- """
- if tensor_img.shape[0] == 3:
- tensor_img = tensor_img[[2, 1, 0], ...]
- return tensor_img
-
-
-def bgr2ycbcr(img_bgr):
- """Convert BGR image to YCbCr color space."""
- img_bgr = img_bgr.astype(np.float32)
- img_ycrcb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2YCR_CB)
- img_ycbcr = img_ycrcb[:, :, (0, 2, 1)].astype(np.float32)
- img_ycbcr[:, :, 0] = (img_ycbcr[:, :, 0] * (235 - 16) + 16) / 255.0
- img_ycbcr[:, :, 1:] = (img_ycbcr[:, :, 1:] * (240 - 16) + 16) / 255.0
- return img_ycbcr
-
-
-def ycbcr2bgr(img_ycbcr):
- """Convert YCbCr image to BGR color space."""
- img_ycbcr = img_ycbcr.astype(np.float32)
- img_ycbcr[:, :, 0] = (img_ycbcr[:, :, 0] * 255.0 - 16) / (235 - 16)
- img_ycbcr[:, :, 1:] = (img_ycbcr[:, :, 1:] * 255.0 - 16) / (240 - 16)
- img_ycrcb = img_ycbcr[:, :, (0, 2, 1)].astype(np.float32)
- img_bgr = cv2.cvtColor(img_ycrcb, cv2.COLOR_YCR_CB2BGR)
- return img_bgr
-
-
-def color_saturation(img, param):
- """Apply color saturation distortion."""
- ycbcr = bgr2ycbcr(img)
- ycbcr[:, :, 1] = 0.5 + (ycbcr[:, :, 1] - 0.5) * param
- ycbcr[:, :, 2] = 0.5 + (ycbcr[:, :, 2] - 0.5) * param
- img = ycbcr2bgr(ycbcr).astype(np.uint8)
- return img
-
-
-def color_contrast(img, param):
- """Apply color contrast distortion."""
- img = img.astype(np.float32) * param
- return img.astype(np.uint8)
-
-
-def block_wise(img, param):
- """Apply block-wise distortion."""
- width = 8
- block = np.ones((width, width, 3)).astype(int) * 128
- param = min(img.shape[0], img.shape[1]) // 256 * param
- for _ in range(param):
- r_w = random.randint(0, img.shape[1] - 1 - width)
- r_h = random.randint(0, img.shape[0] - 1 - width)
- img[r_h:r_h + width, r_w:r_w + width, :] = block
- return img
-
-
-def gaussian_noise_color(img, param):
- """Apply colored Gaussian noise."""
- ycbcr = bgr2ycbcr(img) / 255
- size_a = ycbcr.shape
- b = (ycbcr + math.sqrt(param) * np.random.randn(size_a[0], size_a[1], size_a[2])) * 255
- b = ycbcr2bgr(b)
- return np.clip(b, 0, 255).astype(np.uint8)
-
-
-def gaussian_blur(img, param):
- """Apply Gaussian blur."""
- return cv2.GaussianBlur(img, (param, param), param * 1.0 / 6)
-
-
-def jpeg_compression(img, param):
- """Apply JPEG compression distortion."""
- h, w, _ = img.shape
- s_h = h // param
- s_w = w // param
- img = cv2.resize(img, (s_w, s_h))
- return cv2.resize(img, (w, h))
-
-
-class ApplyDeeperForensicsDistortion:
- """Wrapper for applying DeeperForensics distortions."""
-
- def __init__(self, distortion_type, level_min=0, level_max=3):
- self.distortion_type = distortion_type
- self.level_min = level_min
- self.level_max = level_max
- self.params = {}
-
- def __call__(self, img, level=None):
- if level is None:
- self.level = random.randint(self.level_min, self.level_max)
- else:
- self.level = level
-
- if self.level > 0:
- self.distortion_param = get_distortion_parameter(self.distortion_type, self.level)
- self.distortion_func = get_distortion_function(self.distortion_type)
- else:
- self.distortion_func = None
- self.distortion_param = None
-
- if not self.distortion_func:
- return img
-
- if isinstance(img, torch.Tensor):
- img = rgb_to_bgr(img)
- img = img.permute(1, 2, 0).cpu().numpy()
- img = (img * 255).astype(np.uint8)
-
- self.params = {'level': self.level}
- img = self.distortion_func(img, self.distortion_param)
-
- if isinstance(img, np.ndarray):
- img = torch.from_numpy(img.astype(np.float32) / 255.0)
- img = img.permute(2, 0, 1)
- img = bgr_to_rgb(img)
-
- return img
-
-
-class CLAHE:
- """Contrast Limited Adaptive Histogram Equalization."""
-
- def __init__(self):
- self.clahe = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(8,8))
-
- def __call__(self, image):
- # Convert PIL image to NumPy array
- image_np = np.array(image)
-
- # Apply CLAHE to each channel separately if it's a color image
- if len(image_np.shape) == 3: # Color image
- channels = cv2.split(image_np)
- clahe_channels = [self.clahe.apply(ch) for ch in channels]
- clahe_image_np = cv2.merge(clahe_channels)
- else: # Grayscale image
- clahe_image_np = self.clahe.apply(image_np)
-
- # Convert back to PIL image
- clahe_image = Image.fromarray(clahe_image_np)
-
- return clahe_image
-
-
-class TensorCLAHE:
- def __init__(self):
- self.clahe = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(8,8))
-
- def __call__(self, tensor):
- # Convert tensor to numpy array (H,W,C) format
- img_np = tensor.permute(1, 2, 0).numpy() * 255
- img_np = img_np.astype(np.uint8)
-
- # Apply CLAHE to each channel
- channels = cv2.split(img_np)
- clahe_channels = [self.clahe.apply(ch) for ch in channels]
- clahe_image_np = cv2.merge(clahe_channels)
-
- # Convert back to tensor
- tensor = torch.from_numpy(clahe_image_np).float() / 255.0
- return tensor.permute(2, 0, 1)
-
-
-class ComposeWithParams:
- def __init__(self, transforms):
- self.transforms = transforms
- self.params = {}
-
- def __call__(self, input_data, clear_params=True):
- if clear_params:
- self.params = {}
-
- output_data = []
- list_input = True
- if not isinstance(input_data, list):
- input_data = [input_data]
- list_input = False
-
- for img in input_data:
- for transform in self.transforms:
- try:
- name = transform.__name__
- except AttributeError:
- name = transform.__class__.__name__
-
- if name in self.params:
- img = transform(img, **self.params[name])
- else:
- img = transform(img)
- if hasattr(transform, 'params'):
- self.params[name] = transform.params
- output_data.append(img)
-
- if list_input:
- return output_data
- return output_data[0]
-
-
-# Transform configurations
-def get_base_transforms(target_image_size=TARGET_IMAGE_SIZE):
- return ComposeWithParams([
- ConvertToRGB(),
- center_crop(),
- transforms.Resize(target_image_size),
- transforms.ToTensor()
- ])
-
-
-def get_random_augmentations(target_image_size=TARGET_IMAGE_SIZE, mask_point=None):
- return ComposeWithParams([
- ConvertToRGB(),
- transforms.ToTensor(),
- RandomRotationWithParams(20, interpolation=transforms.InterpolationMode.BILINEAR),
- RandomResizedCropWithParams(
-            target_image_size, scale=(0.2, 1.0), ratio=(1.0, 1.0), include_point=mask_point),
- RandomHorizontalFlipWithParams(),
- RandomVerticalFlipWithParams()
- ])
-
-def get_ucf_base_transforms(target_image_size=TARGET_IMAGE_SIZE):
- return transforms.Compose([
- ConvertToRGB(),
- center_crop(),
- transforms.Resize(target_image_size),
- CLAHE(),
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
- ])
-
-def get_tall_base_transforms(target_image_size=TARGET_IMAGE_SIZE):
- return ComposeWithParams([
- transforms.Resize(target_image_size),
- transforms.ToTensor()
- ])
-
-# Medium difficulty transforms with mild distortions
-def get_random_augmentations_medium(target_image_size=TARGET_IMAGE_SIZE, mask_point=None):
- return ComposeWithParams([
- ConvertToRGB(),
- transforms.ToTensor(),
- RandomRotationWithParams(20, interpolation=transforms.InterpolationMode.BILINEAR),
- RandomResizedCropWithParams(
-            target_image_size, scale=(0.2, 1.0), ratio=(1.0, 1.0), include_point=mask_point),
- RandomHorizontalFlipWithParams(),
- RandomVerticalFlipWithParams(),
- ApplyDeeperForensicsDistortion('CS', level_min=0, level_max=1),
- ApplyDeeperForensicsDistortion('CC', level_min=0, level_max=1),
- ApplyDeeperForensicsDistortion('JPEG', level_min=0, level_max=1)
- ])
-
-# Hard difficulty transforms with more severe distortions
-def get_random_augmentations_hard(target_image_size=TARGET_IMAGE_SIZE, mask_point=None):
- return ComposeWithParams([
- ConvertToRGB(),
- transforms.ToTensor(),
- RandomRotationWithParams(20, interpolation=transforms.InterpolationMode.BILINEAR),
- RandomResizedCropWithParams(
-            target_image_size, scale=(0.2, 1.0), ratio=(1.0, 1.0), include_point=mask_point),
- RandomHorizontalFlipWithParams(),
- RandomVerticalFlipWithParams(),
- ApplyDeeperForensicsDistortion('CS', level_min=0, level_max=2),
- ApplyDeeperForensicsDistortion('CC', level_min=0, level_max=2),
- ApplyDeeperForensicsDistortion('JPEG', level_min=0, level_max=2),
- ApplyDeeperForensicsDistortion('GNC', level_min=0, level_max=2),
- ApplyDeeperForensicsDistortion('GB', level_min=0, level_max=2)
- ])
-
-
-def apply_augmentation_by_level(
- image,
- target_image_size,
- mask_point=None,
- level_probs={
- 0: 0.25, # No augmentations (base transforms)
- 1: 0.25, # Basic augmentations
- 2: 0.25, # Medium distortions
- 3: 0.25 # Hard distortions
- }):
- """
- Apply image transformations based on randomly selected level.
-
- Args:
- image: PIL Image to transform
- level_probs: dict with augmentation levels and their probabilities.
- Default probabilities:
- - Level 0 (25%): No augmentations (base transforms)
-            - Level 1 (25%): Basic augmentations
-            - Level 2 (25%): Medium distortions
-            - Level 3 (25%): Hard distortions
-
- Returns:
- tuple: (transformed_image, level, transform_params)
-
- Raises:
- ValueError: If probabilities don't sum to 1.0 (within floating point precision)
- """
- # Validate probabilities
- if not math.isclose(sum(level_probs.values()), 1.0, rel_tol=1e-9):
- raise ValueError("Probabilities of levels must sum to 1.0")
-
- # Calculate cumulative probabilities
- cumulative_probs = {}
- cumsum = 0
- for level, prob in sorted(level_probs.items()):
- cumsum += prob
- cumulative_probs[level] = cumsum
-
- # Select augmentation level
- rand_val = np.random.random()
- for curr_level, cum_prob in cumulative_probs.items():
- if rand_val <= cum_prob:
- level = curr_level
- break
-
- # Apply appropriate transform
- if level == 0:
- tforms = get_base_transforms(target_image_size)
- elif level == 1:
- tforms = get_random_augmentations(target_image_size, mask_point)
- elif level == 2:
- tforms = get_random_augmentations_medium(target_image_size, mask_point)
- else: # level == 3
- tforms = get_random_augmentations_hard(target_image_size, mask_point)
-
- transformed = tforms(image)
-
- return transformed, level, tforms.params
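As a rough illustration of how the removed augmentation entry point was driven (the import path matches the usage in challenge.py further below; the image path and probabilities are placeholders):

from PIL import Image
from bitmind.utils.image_transforms import apply_augmentation_by_level

image = Image.open("example.jpg")  # any RGB image
augmented, level, params = apply_augmentation_by_level(
    image,
    target_image_size=(256, 256),
    mask_point=None,
    level_probs={0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25},  # must sum to 1.0
)
print(level, params)  # e.g. 1 and the per-transform parameters recorded by ComposeWithParams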
diff --git a/bitmind/utils/logging.py b/bitmind/utils/logging.py
deleted file mode 100644
index eea6eb0c..00000000
--- a/bitmind/utils/logging.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import os
-import logging
-from logging.handlers import RotatingFileHandler
-
-EVENTS_LEVEL_NUM = 38
-DEFAULT_LOG_BACKUP_COUNT = 10
-
-
-def setup_events_logger(full_path, events_retention_size):
- logging.addLevelName(EVENTS_LEVEL_NUM, "EVENT")
-
- logger = logging.getLogger("event")
- logger.setLevel(EVENTS_LEVEL_NUM)
-
- def event(self, message, *args, **kws):
- if self.isEnabledFor(EVENTS_LEVEL_NUM):
- self._log(EVENTS_LEVEL_NUM, message, args, **kws)
-
- logging.Logger.event = event
-
- formatter = logging.Formatter(
- "%(asctime)s | %(levelname)s | %(message)s",
- datefmt="%Y-%m-%d %H:%M:%S",
- )
-
- file_handler = RotatingFileHandler(
- os.path.join(full_path, "events.log"),
- maxBytes=events_retention_size,
- backupCount=DEFAULT_LOG_BACKUP_COUNT,
- )
- file_handler.setFormatter(formatter)
- file_handler.setLevel(EVENTS_LEVEL_NUM)
- logger.addHandler(file_handler)
-
- return logger
-
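A minimal sketch of how this events logger was typically wired up (the directory and retention size are illustrative; the directory must already exist):

from bitmind.utils.logging import setup_events_logger

events_logger = setup_events_logger("/path/to/validator", events_retention_size=10 * 1024 * 1024)
events_logger.event("challenge issued")  # written to events.log at the custom EVENT level (38)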
diff --git a/bitmind/utils/misc.py b/bitmind/utils/misc.py
deleted file mode 100644
index 80b4e614..00000000
--- a/bitmind/utils/misc.py
+++ /dev/null
@@ -1,112 +0,0 @@
-# The MIT License (MIT)
-# Copyright © 2023 Yuma Rao
-# Copyright © 2023 Opentensor Foundation
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
-
-# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-
-import time
-import math
-import hashlib as rpccheckhealth
-from math import floor
-from typing import Callable, Any
-from functools import lru_cache, update_wrapper
-
-
-# LRU Cache with TTL
-def ttl_cache(maxsize: int = 128, typed: bool = False, ttl: int = -1):
- """
- Decorator that creates a cache of the most recently used function calls with a time-to-live (TTL) feature.
- The cache evicts the least recently used entries if the cache exceeds the `maxsize` or if an entry has
- been in the cache longer than the `ttl` period.
-
- Args:
- maxsize (int): Maximum size of the cache. Once the cache grows to this size, subsequent entries
- replace the least recently used ones. Defaults to 128.
- typed (bool): If set to True, arguments of different types will be cached separately. For example,
- f(3) and f(3.0) will be treated as distinct calls with distinct results. Defaults to False.
- ttl (int): The time-to-live for each cache entry, measured in seconds. If set to a non-positive value,
- the TTL is set to a very large number, effectively making the cache entries permanent. Defaults to -1.
-
- Returns:
- Callable: A decorator that can be applied to functions to cache their return values.
-
- The decorator is useful for caching results of functions that are expensive to compute and are called
- with the same arguments frequently within short periods of time. The TTL feature helps in ensuring
- that the cached values are not stale.
-
- Example:
- @ttl_cache(ttl=10)
- def get_data(param):
- # Expensive data retrieval operation
- return data
- """
- if ttl <= 0:
- ttl = 65536
- hash_gen = _ttl_hash_gen(ttl)
-
- def wrapper(func: Callable) -> Callable:
- @lru_cache(maxsize, typed)
- def ttl_func(ttl_hash, *args, **kwargs):
- return func(*args, **kwargs)
-
- def wrapped(*args, **kwargs) -> Any:
- th = next(hash_gen)
- return ttl_func(th, *args, **kwargs)
-
- return update_wrapper(wrapped, func)
-
- return wrapper
-
-
-def _ttl_hash_gen(seconds: int):
- """
- Internal generator function used by the `ttl_cache` decorator to generate a new hash value at regular
- time intervals specified by `seconds`.
-
- Args:
- seconds (int): The number of seconds after which a new hash value will be generated.
-
- Yields:
- int: A hash value that represents the current time interval.
-
- This generator is used to create time-based hash values that enable the `ttl_cache` to determine
- whether cached entries are still valid or if they have expired and should be recalculated.
- """
- start_time = time.time()
- while True:
- yield floor((time.time() - start_time) / seconds)
-
-
-# 12 seconds updating block.
-@ttl_cache(maxsize=1, ttl=12)
-def ttl_get_block(self) -> int:
- """
- Retrieves the current block number from the blockchain. This method is cached with a time-to-live (TTL)
- of 12 seconds, meaning that it will only refresh the block number from the blockchain at most every 12 seconds,
- reducing the number of calls to the underlying blockchain interface.
-
- Returns:
- int: The current block number on the blockchain.
-
- This method is useful for applications that need to access the current block number frequently and can
- tolerate a delay of up to 12 seconds for the latest information. By using a cache with TTL, the method
- efficiently reduces the workload on the blockchain interface.
-
- Example:
- current_block = ttl_get_block(self)
-
- Note: self here is the miner or validator instance
- """
- return self.subtensor.get_current_block()
diff --git a/bitmind/utils/mock.py b/bitmind/utils/mock.py
deleted file mode 100644
index 8fae787f..00000000
--- a/bitmind/utils/mock.py
+++ /dev/null
@@ -1,227 +0,0 @@
-import time
-import asyncio
-import random
-import bittensor as bt
-import numpy as np
-from typing import List
-from PIL import Image
-
-from bitmind.validator.config import MODEL_NAMES
-from bitmind.validator.miner_performance_tracker import MinerPerformanceTracker
-
-
-def create_random_image():
- random_data = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
- return Image.fromarray(random_data)
-
-
-class MockImageDataset:
- def __init__(
- self,
- huggingface_dataset_path: str,
- huggingface_datset_split: str = 'train',
- huggingface_datset_name: str = None,
- create_splits: bool = False,
- download_mode: str = None):
-
- self.huggingface_dataset_path = huggingface_dataset_path
- self.huggingface_dataset_name = huggingface_datset_name
- self.dataset = ""
- self.sampled_images_idx = []
-
- def __getitem__(self, index: int) -> dict:
- return {
- 'image': create_random_image(),
- 'id': index,
- 'source': self.huggingface_dataset_path
- }
-
- def __len__(self):
- return 100 # mock length
-
- def sample(self, k=1):
- return [self.__getitem__(i) for i in range(k)], [i for i in range(k)]
-
-
-class MockSyntheticDataGenerator:
- def __init__(self, prompt_type, use_random_t2v_model, t2v_model_name):
- self.prompt_type = prompt_type
- self.t2v_model_name = t2v_model_name
- self.use_random_t2v_model = use_random_t2v_model
-
- def generate(self, k=1, real_images=None, modality='image'):
- if self.use_random_t2v_model:
- self.load_t2v_model('random')
- else:
- self.load_t2v_model(self.t2v_model_name)
-
- return [{
- 'prompt': f'mock {self.prompt_type} prompt',
- 'image': create_random_image(),
- 'id': i
- } for i in range(k)]
-
-    def load_t2v_model(self, t2v_model_name) -> None:
-        """
-        Mock loading of a text-to-video model; selects a random model name if 'random' is given.
- """
- if t2v_model_name == 'random':
- t2v_model_name = np.random.choice(MODEL_NAMES, 1)[0]
- self.t2v_model_name = t2v_model_name
-
-
-class MockValidator:
- def __init__(self, config):
- self.config = config
- subtensor = MockSubtensor(config.netuid, wallet=bt.MockWallet())
-
- self.performance_tracker = MinerPerformanceTracker()
-
- self.metagraph = MockMetagraph(
- netuid=config.netuid,
- subtensor=subtensor
- )
- self.dendrite = MockDendrite(bt.MockWallet())
- self.real_image_datasets = [
- MockImageDataset(
- f"fake-path/dataset-{i}",
- 'train',
- None,
- False)
- for i in range(3)
- ]
- self.synthetic_data_generator = MockSyntheticDataGenerator(
-            prompt_type='annotation', use_random_t2v_model=True, t2v_model_name=None)
- self.total_real_images = sum([len(ds) for ds in self.real_image_datasets])
- self.scores = np.zeros(self.metagraph.n, dtype=np.float32)
- self._fake_prob = config.fake_prob
-
- def update_scores(self, rewards, miner_uids):
- pass
-
- def save_miner_history(self):
- pass
-
-
-class MockSubtensor(bt.MockSubtensor):
- def __init__(self, netuid, n=16, wallet=None, network="mock"):
- super().__init__(network=network)
- bt.MockSubtensor.reset() # reset chain state so test cases don't interfere with one another
-
- if not self.subnet_exists(netuid):
- self.create_subnet(netuid)
-
- # Register ourself (the validator) as a neuron at uid=0
- if wallet is not None:
- try:
- self.force_register_neuron(
- netuid=netuid,
- hotkey=wallet.hotkey.ss58_address,
- coldkey=wallet.coldkey.ss58_address,
- balance=100000,
- stake=100000,
- )
- except Exception as e:
- print(f"Skipping force_register_neuron: {e}")
-
- # Register n mock neurons who will be miners
- for i in range(1, n + 1):
- try:
- self.force_register_neuron(
- netuid=netuid,
- hotkey=f"miner-hotkey-{i}",
- coldkey="mock-coldkey",
- balance=100000,
- stake=100000,
- )
- except Exception as e:
- print(f"Skipping force_register_neuron: {e}")
-
-
-class MockMetagraph(bt.metagraph):
- def __init__(self, netuid, network="mock", subtensor=None):
- super().__init__(netuid=netuid, network=network, sync=False)
- self.default_ip = "127.0.0.0"
- self.default_port = 8092
-
- if subtensor is not None:
- self.subtensor = subtensor
- self.sync(subtensor=subtensor)
-
- for axon in self.axons:
- axon.ip = self.default_ip
- axon.port = self.default_port
-
- bt.logging.info(f"Metagraph: {self}")
- bt.logging.info(f"Axons: {self.axons}")
-
-
-class MockDendrite(bt.dendrite):
- """
-    Replaces a real bittensor network request with a mock that returns a static response for each queried axon after a small random delay.
- """
-
- def __init__(self, wallet):
- super().__init__(wallet)
-
- async def forward(
- self,
- axons: List[bt.axon],
- synapse: bt.Synapse = bt.Synapse(),
- timeout: float = 12,
- deserialize: bool = True,
- run_async: bool = True,
- streaming: bool = False,
- ):
- if streaming:
- raise NotImplementedError("Streaming not implemented yet.")
-
- async def query_all_axons(streaming: bool):
- """Queries all axons for responses."""
-
- async def single_axon_response(i, axon):
- """Queries a single axon for a response."""
-
- start_time = time.time()
- s = synapse.copy()
- # Attach some more required data so it looks real
- s = self.preprocess_synapse_for_request(axon, s, timeout)
- # We just want to mock the response, so we'll just fill in some data
- process_time = random.random()
- if process_time < timeout:
- s.dendrite.process_time = str(time.time() - start_time)
- # Update the status code and status message of the dendrite to match the axon
- # TODO (developer): replace with your own expected synapse data
- s.prediction = np.random.rand(1)[0]
- s.dendrite.status_code = 200
- s.dendrite.status_message = "OK"
- s.dendrite.process_time = str(process_time)
- else:
- s.prediction = -1
- s.dendrite.status_code = 408
- s.dendrite.status_message = "Timeout"
- s.dendrite.process_time = str(timeout)
-
- # Return the updated synapse object after deserializing if requested
- if deserialize:
- return s.deserialize()
- else:
- return s
-
- return await asyncio.gather(
- *(
- single_axon_response(i, target_axon)
- for i, target_axon in enumerate(axons)
- )
- )
-
- return await query_all_axons(streaming)
-
- def __str__(self) -> str:
- """
- Returns a string representation of the Dendrite object.
-
- Returns:
- str: The string representation of the Dendrite object in the format "dendrite()".
- """
- return "MockDendrite({})".format(self.keypair.ss58_address)
diff --git a/bitmind/utils/uids.py b/bitmind/utils/uids.py
deleted file mode 100644
index e0300402..00000000
--- a/bitmind/utils/uids.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import random
-import bittensor as bt
-import numpy as np
-from typing import List
-
-
-def check_uid_availability(
- metagraph: "bt.metagraph.Metagraph", uid: int, vpermit_tao_limit: int
-) -> bool:
- """Check if uid is available. The UID should be available if it is serving and has less than vpermit_tao_limit stake
- Args:
- metagraph (:obj: bt.metagraph.Metagraph): Metagraph object
- uid (int): uid to be checked
- vpermit_tao_limit (int): Validator permit tao limit
- Returns:
- bool: True if uid is available, False otherwise
- """
- # Filter non serving axons.
- if not metagraph.axons[uid].is_serving:
- return False
- # Filter validator permit > 1024 stake.
- if metagraph.validator_permit[uid]:
- if metagraph.S[uid] > vpermit_tao_limit:
- return False
- # Available otherwise.
- return True
-
-
-def get_random_uids(
- self, k: int, exclude: List[int] = None
-) -> np.ndarray:
- """Returns k available random uids from the metagraph.
- Args:
- k (int): Number of uids to return.
- exclude (List[int]): List of uids to exclude from the random sampling.
- Returns:
- uids (np.ndarray): Randomly sampled available uids.
- Notes:
- If `k` is larger than the number of available `uids`, set `k` to the number of available `uids`.
- """
- candidate_uids = []
- avail_uids = []
-
- for uid in range(self.metagraph.n.item()):
- uid_is_available = check_uid_availability(
- self.metagraph, uid, self.config.neuron.vpermit_tao_limit
- )
- uid_is_not_excluded = exclude is None or uid not in exclude
-
- if uid_is_available:
- avail_uids.append(uid)
- if uid_is_not_excluded:
- candidate_uids.append(uid)
- # If k is larger than the number of available uids, set k to the number of available uids.
- k = min(k, len(avail_uids))
- # Check if candidate_uids contain enough for querying, if not grab all avaliable uids
- available_uids = candidate_uids
- if len(candidate_uids) < k:
- available_uids += random.sample(
- [uid for uid in avail_uids if uid not in candidate_uids],
- k - len(candidate_uids),
- )
- uids = np.array(random.sample(available_uids, k))
- return uids
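A hedged usage sketch: get_random_uids expects a neuron-like object as its first argument, so `validator` below stands in for a running validator instance exposing .metagraph and .config.neuron.vpermit_tao_limit; it is assumed, not defined here.

from bitmind.utils.uids import get_random_uids

miner_uids = get_random_uids(validator, k=10, exclude=[0])  # `validator` is a placeholder
print(miner_uids)  # e.g. array([ 5, 42,  3, 17, ...])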
diff --git a/bitmind/utils/video_utils.py b/bitmind/utils/video_utils.py
deleted file mode 100644
index ffb2a7ff..00000000
--- a/bitmind/utils/video_utils.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import torch
-
-
-def pad_frames(x, divisible_by):
- """
- Pads the tensor `x` along the frame dimension (1) until the number of frames is divisible by `divisible_by`.
-
- Args:
- x (torch.Tensor): Input tensor of shape (batch_size, num_frames, channels, height, width).
- divisible_by (int): The divisor to make the number of frames divisible by.
-
- Returns:
- torch.Tensor: Padded tensor of shape (batch_size, adjusted_num_frames, channels, height, width).
- """
- num_frames = x.shape[1]
- frame_padding = (divisible_by - (num_frames % divisible_by)) % divisible_by
-
- if frame_padding > 0:
- padding_shape = (x.shape[0], frame_padding, x.shape[2], x.shape[3], x.shape[4])
-        x_padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype)  # keep padding on the same device and dtype as the input
- x = torch.cat((x, x_padding), dim=1)
-
- assert x.shape[1] % divisible_by == 0, (
- f'Frame number mismatch: got {x.shape[1]} frames, not divisible by {divisible_by}.'
- )
- return x
\ No newline at end of file
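A quick sketch of the padding behaviour (shapes are illustrative):

import torch
from bitmind.utils.video_utils import pad_frames

frames = torch.randn(1, 10, 3, 224, 224)      # 10 frames
padded = pad_frames(frames, divisible_by=4)   # appends 2 zero frames
print(padded.shape)                           # torch.Size([1, 12, 3, 224, 224])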
diff --git a/bitmind/validator/__init__.py b/bitmind/validator/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/bitmind/validator/cache/__init__.py b/bitmind/validator/cache/__init__.py
deleted file mode 100644
index 8858fff1..00000000
--- a/bitmind/validator/cache/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .base_cache import BaseCache
-from .image_cache import ImageCache
-from .video_cache import VideoCache
diff --git a/bitmind/validator/cache/base_cache.py b/bitmind/validator/cache/base_cache.py
deleted file mode 100644
index a42831d5..00000000
--- a/bitmind/validator/cache/base_cache.py
+++ /dev/null
@@ -1,335 +0,0 @@
-from abc import ABC, abstractmethod
-import asyncio
-from datetime import datetime
-from pathlib import Path
-import time
-from typing import Any, Dict, List, Optional, Union
-
-import bittensor as bt
-import huggingface_hub as hf_hub
-import numpy as np
-
-from .util import get_most_recent_update_time, seconds_to_str
-from .download import download_files, list_hf_files
-
-
-class BaseCache(ABC):
- """
- Abstract base class for managing file caches with compressed sources.
-
- This class provides the basic infrastructure for maintaining both a compressed
- source cache and an extracted cache, with automatic refresh intervals and
- background update tasks.
- """
-
- def __init__(
- self,
- cache_dir: Union[str, Path],
- file_extensions: List[str],
- compressed_file_extension: str,
- datasets: dict = None,
- extracted_update_interval: int = 4,
- compressed_update_interval: int = 12,
- num_sources_per_dataset: int = 1,
- max_compressed_size_gb: float = 100.0,
- max_extracted_size_gb: float = 10.0,
- ) -> None:
- """
- Initialize the base cache infrastructure.
-
- Args:
-            cache_dir: Path to store extracted files
-            file_extensions: List of valid file extensions for this cache type
-            compressed_file_extension: Extension of compressed source files (e.g. '.zip', '.parquet')
-            datasets: Dataset configuration used to populate the compressed cache
-            extracted_update_interval: Hours between extracted cache updates
-            compressed_update_interval: Hours between compressed cache updates
-            num_sources_per_dataset: Number of compressed sources to download per dataset
-            max_compressed_size_gb: Maximum size in GB for compressed cache directory
-            max_extracted_size_gb: Maximum size in GB for extracted cache directory
- """
- self.cache_dir = Path(cache_dir)
- self.cache_dir.mkdir(exist_ok=True, parents=True)
-
- self.compressed_dir = self.cache_dir / 'sources'
- self.compressed_dir.mkdir(exist_ok=True, parents=True)
-
- self.datasets = datasets
-
- self.extracted_update_interval = extracted_update_interval * 60 * 60
- self.compressed_update_interval = compressed_update_interval * 60 * 60
- self.num_sources_per_dataset = num_sources_per_dataset
- self.file_extensions = file_extensions
- self.compressed_file_extension = compressed_file_extension
- self.max_compressed_size_bytes = max_compressed_size_gb * 1024 * 1024 * 1024
- self.max_extracted_size_bytes = max_extracted_size_gb * 1024 * 1024 * 1024
-
- def start_updater(self):
- """Start the background updater tasks for compressed and extracted caches."""
- if not self.datasets:
- bt.logging.error("No datasets configured. Cannot start cache updater.")
- return
-
- try:
- self.loop = asyncio.get_running_loop()
- except RuntimeError:
- self.loop = asyncio.get_event_loop()
-
- # Initialize caches, blocking to ensure data are available for validator
- bt.logging.info(f"Setting up cache at {self.cache_dir}")
- bt.logging.info(f"Clearing incomplete sources in {self.compressed_dir}")
- self._clear_incomplete_sources()
-
- if self._extracted_cache_empty():
- if self._compressed_cache_empty():
- bt.logging.info(f"Compressed cache {self.compressed_dir} empty; populating")
-                # grab a single source from one dataset first so the validator has data available quickly
-                for n_datasets in [1, None]:
-                    self._refresh_compressed_cache(n_sources_per_dataset=1, n_datasets=n_datasets)
- self._refresh_extracted_cache()
- else:
- bt.logging.info(f"Extracted cache {self.cache_dir} empty; populating")
- self._refresh_extracted_cache()
-
- # Start background tasks
- bt.logging.info(f"Starting background tasks")
- self._compressed_updater_task = self.loop.create_task(
- self._run_compressed_updater()
- )
- self._extracted_updater_task = self.loop.create_task(
- self._run_extracted_updater()
- )
-
- def _get_files(self, cache="extracted", max_depth: int = 3, group_by_source: bool = False) -> Union[List[Path], Dict[str, List[Path]]]:
- """Get list of files from the specified cache up to a maximum depth.
-
- Args:
- cache: Which cache to retrieve files from - either "extracted" or "compressed"
- max_depth: Maximum directory depth to search
- group_by_source: If True, returns a dictionary of files grouped by subdirectory
-
- Returns:
- Either a flat list of Path objects or a dictionary mapping subdirectories to lists of files
- """
- if cache == "extracted":
- base_dir = self.cache_dir
- extensions = self.file_extensions
- elif cache == "compressed":
- base_dir = self.compressed_dir
- extensions = [self.compressed_file_extension]
- else:
- raise ValueError(f"Invalid cache type: {cache}. Must be 'extracted' or 'compressed'")
-
- files = []
- for depth in range(max_depth + 1):
- pattern = '*/' * depth + '*'
- files.extend([
- f for f in base_dir.glob(pattern)
- if f.is_file() and any(f.suffix.lower() == ext.lower() for ext in extensions)
- ])
-
- if not group_by_source:
- return files
-
- # Group files by subdirectory
- subdirectory_files = {}
- for file in files:
- if file.exists():
- try:
- rel_path = file.relative_to(base_dir)
- subdir = str(rel_path.parent)
- except ValueError:
- # Fallback if relative_to fails
- subdir = str(file.parent)
-
- if subdir not in subdirectory_files:
- subdirectory_files[subdir] = []
- subdirectory_files[subdir].append(file)
-
- return subdirectory_files
-
- def _get_cached_files(self, max_depth: int = 3, group_by_source: bool = False) -> Union[List[Path], Dict[str, List[Path]]]:
- return self._get_files(cache="extracted", max_depth=max_depth, group_by_source=group_by_source)
-
- def _get_compressed_files(self, max_depth: int = 3, group_by_source: bool = False) -> Union[List[Path], Dict[str, List[Path]]]:
- return self._get_files(cache="compressed", max_depth=max_depth, group_by_source=group_by_source)
-
- def _extracted_cache_empty(self) -> bool:
- return len(self._get_cached_files()) == 0
-
- def _compressed_cache_empty(self) -> bool:
- return len(self._get_compressed_files()) == 0
-
- def prune_cache(self, cache="compressed") -> None:
- """Check extracted cache size and remove oldest files if over limit.
- Balances deletion across subdirectories to avoid emptying smaller directories."""
- if cache == 'compressed':
- cache_dir = self.compressed_dir
- max_size = self.max_compressed_size_bytes
- elif cache == 'extracted':
- cache_dir = self.cache_dir
- max_size = self.max_extracted_size_bytes
- else:
- raise ValueError(f"Invalid cache type: {cache}. Must be 'extracted' or 'compressed'")
-
- files_dict = self._get_files(cache=cache, group_by_source=True)
-
- all_files = [f for subdir_files in files_dict.values() for f in subdir_files]
- total_size = sum(f.stat().st_size for f in all_files if f.exists())
- total_size_gb = total_size / (1024*1024*1024)
- bt.logging.info(f"[{cache_dir}] Cache size: {len(all_files)} files | {total_size_gb:.6f} GB")
- if total_size <= max_size:
- return
-
-        # Sort each subdirectory's files by modification time
- for subdir in files_dict:
- files_dict[subdir] = sorted(
- files_dict[subdir],
- key=lambda f: f.stat().st_mtime if f.exists() else float('inf')
- )
-
- n_removed = 0
- bytes_removed = 0
- remaining_size = total_size
-
- bt.logging.info(f"[{cache_dir}] Pruning...")
-
- while remaining_size > max_size and any(len(files) > 0 for files in files_dict.values()):
- largest_subdir = max(
- files_dict.keys(),
- key=lambda subdir: len(files_dict[subdir]),
- default=None
- )
-
- if largest_subdir is None or not files_dict[largest_subdir]:
- break
-
- file = files_dict[largest_subdir].pop(0)
- try:
- if file.exists():
- file_size = file.stat().st_size
- file.unlink()
- json_file = file.with_suffix('.json')
- if json_file.exists():
- json_file.unlink()
-
- n_removed += 1
- bytes_removed += file_size
- remaining_size -= file_size
- except Exception as e:
- continue
-
- final_size = total_size - bytes_removed
- cache_gb = f"{final_size / (1024*1024*1024):.6f}".rstrip('0')
- removed_gb = f"{bytes_removed / (1024*1024*1024):.6f}".rstrip('0')
- bt.logging.info(f"[{cache_dir}] Removed {n_removed} ({removed_gb} GB) files. New cache size is {cache_gb} GB")
-
- async def _run_extracted_updater(self) -> None:
- """Asynchronously refresh extracted files according to update interval."""
- while True:
- try:
- self.prune_cache('extracted')
- last_update = get_most_recent_update_time(self.cache_dir)
- time_elapsed = time.time() - last_update
-
- if time_elapsed >= self.extracted_update_interval:
- bt.logging.info(f"[{self.cache_dir}] Refreshing cache")
- self._refresh_extracted_cache()
- bt.logging.info(f"[{self.cache_dir}] Cache refresh complete ")
-
- sleep_time = max(0, self.extracted_update_interval - time_elapsed)
- bt.logging.info(f"[{self.cache_dir}] Next media cache refresh in {seconds_to_str(sleep_time)}")
- await asyncio.sleep(sleep_time)
- except Exception as e:
- bt.logging.error(f"[{self.cache_dir}] Error in extracted cache update: {e}")
- await asyncio.sleep(60)
-
- async def _run_compressed_updater(self) -> None:
- """Asynchronously refresh compressed files according to update interval."""
- while True:
- try:
- self._clear_incomplete_sources()
- self.prune_cache('compressed')
- last_update = get_most_recent_update_time(self.compressed_dir)
- time_elapsed = time.time() - last_update
-
- if time_elapsed >= self.compressed_update_interval:
- cache_state_before = self._get_compressed_files()
- bt.logging.info(f"[{self.compressed_dir}] Refreshing cache")
- self._refresh_compressed_cache()
- bt.logging.info(f"[{self.compressed_dir}] Cache refresh complete")
- if set(cache_state_before) == set(self._get_compressed_files()):
- bt.logging.warning(f"[{self.compressed_dir}] All datasets small enough to store locally. Stopping updater.")
- return
-
- sleep_time = max(0, self.compressed_update_interval - time_elapsed)
- bt.logging.info(f"[{self.compressed_dir}] Next compressed cache refresh in {seconds_to_str(sleep_time)}")
- await asyncio.sleep(sleep_time)
- except Exception as e:
- bt.logging.error(f"[{self.compressed_dir}] Error in compressed cache update: {e}")
- await asyncio.sleep(60)
-
- def _refresh_compressed_cache(
- self,
- n_sources_per_dataset: Optional[int] = None,
- n_datasets: Optional[int] = None
- ) -> None:
- """
- Refresh the compressed file cache with new downloads.
- """
- if n_sources_per_dataset is None:
- n_sources_per_dataset = self.num_sources_per_dataset
-
- try:
- bt.logging.info(f"{len(self._get_compressed_files())} compressed sources currently cached")
-
- new_files: List[Path] = []
- for dataset in self.datasets[:n_datasets]:
- filenames = list_hf_files(
- repo_id=dataset['path'],
- extension=self.compressed_file_extension)
- remote_paths = [
- f"https://huggingface.co/datasets/{dataset['path']}/resolve/main/{f}"
- for f in filenames
- ]
- bt.logging.info(f"Downloading {n_sources_per_dataset} from {dataset['path']} to {self.compressed_dir}")
- new_files += download_files(
- urls=np.random.choice(remote_paths, n_sources_per_dataset),
- output_dir=self.compressed_dir / dataset['path'].split('/')[1])
-
- if new_files:
- bt.logging.info(f"{len(new_files)} new files added to {self.compressed_dir}")
- else:
- bt.logging.error(f"No new files were added to {self.compressed_dir}")
-
- except Exception as e:
- bt.logging.error(f"Error during compressed refresh for {self.compressed_dir}: {e}")
- raise
-
- def _refresh_extracted_cache(self, n_items_per_source: Optional[int] = None) -> None:
- """Refresh the extracted cache with new selections."""
- bt.logging.info(f"{len(self._get_cached_files())} media files currently cached")
- new_files = self._extract_random_items(n_items_per_source)
- if new_files:
- bt.logging.info(f"{len(new_files)} new files added to {self.cache_dir}")
- else:
- bt.logging.error(f"No new files were added to {self.cache_dir}")
-
- @abstractmethod
- def _extract_random_items(self, n_items_per_source: Optional[int] = None) -> List[Path]:
- """Remove any incomplete or corrupted source files from cache."""
- pass
-
- @abstractmethod
- def _clear_incomplete_sources(self) -> None:
- """Remove any incomplete or corrupted source files from cache."""
- pass
-
- @abstractmethod
- def sample(self, num_samples: int) -> Optional[Dict[str, Any]]:
- """Sample random items from the cache."""
- pass
-
- def __del__(self) -> None:
- """Cleanup background tasks on deletion."""
- if hasattr(self, '_extracted_updater_task'):
- self._extracted_updater_task.cancel()
- if hasattr(self, '_compressed_updater_task'):
- self._compressed_updater_task.cancel()
\ No newline at end of file
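To illustrate the contract BaseCache imposes, here is a hypothetical minimal subclass (the file types and no-op bodies are placeholders, not part of the subnet):

from pathlib import Path
from typing import Any, Dict, List, Optional

from bitmind.validator.cache import BaseCache


class TextCache(BaseCache):
    """Hypothetical cache of .txt files extracted from .zip sources."""

    def __init__(self, cache_dir, datasets=None):
        super().__init__(
            cache_dir=cache_dir,
            file_extensions=['.txt'],
            compressed_file_extension='.zip',
            datasets=datasets,
        )

    def _clear_incomplete_sources(self) -> None:
        pass  # drop corrupted zips here, e.g. via is_zip_complete

    def _extract_random_items(self, n_items_per_source: Optional[int] = None) -> List[Path]:
        return []  # unzip a random selection of .txt files into self.cache_dir

    def sample(self, num_samples: int = 1) -> Optional[Dict[str, Any]]:
        files = self._get_cached_files()
        return {'path': str(files[0])} if files else None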
diff --git a/bitmind/validator/cache/image_cache.py b/bitmind/validator/cache/image_cache.py
deleted file mode 100644
index 1150484d..00000000
--- a/bitmind/validator/cache/image_cache.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import os
-import json
-import random
-from pathlib import Path
-from typing import Dict, List, Optional, Union, Any
-
-import bittensor as bt
-from PIL import Image
-
-from .base_cache import BaseCache
-from .extract import extract_images_from_parquet
-from .util import is_parquet_complete
-
-
-class ImageCache(BaseCache):
- """
- A class to manage image caching from parquet files.
-
- This class handles the caching, updating, and sampling of images stored
- in parquet files. It maintains both a compressed cache of parquet files
- and an extracted cache of images ready for processing.
- """
-
- def __init__(
- self,
- cache_dir: Union[str, Path],
- datasets: Optional[dict] = None,
- parquet_update_interval: int = 6,
- image_update_interval: int = 1,
- num_parquets_per_dataset: int = 5,
- num_images_per_source: int = 100,
- max_compressed_size_gb: int = 100,
- max_extracted_size_gb: int = 10
- ) -> None:
- """
- Args:
- cache_dir: Path to store extracted images
- parquet_update_interval: Hours between parquet cache updates
- image_update_interval: Hours between image cache updates
- num_images_per_source: Number of images to extract per parquet
- """
- super().__init__(
- cache_dir=cache_dir,
- datasets=datasets,
- extracted_update_interval=image_update_interval,
- compressed_update_interval=parquet_update_interval,
- num_sources_per_dataset=num_parquets_per_dataset,
- file_extensions=['.jpg', '.jpeg', '.png'],
- compressed_file_extension='.parquet',
- max_compressed_size_gb=max_compressed_size_gb,
- max_extracted_size_gb=max_extracted_size_gb
- )
- self.num_images_per_source = num_images_per_source
-
- def _clear_incomplete_sources(self) -> None:
- """Remove any incomplete or corrupted parquet files."""
- for path in self._get_compressed_files():
- if path.suffix == '.parquet' and not is_parquet_complete(path):
- try:
- path.unlink()
- bt.logging.warning(f"Removed incomplete parquet file {path}")
- except Exception as e:
- bt.logging.error(f"Error removing incomplete parquet {path}: {e}")
-
- def _extract_random_items(self, n_items_per_source: Optional[int] = None) -> List[Path]:
- """
-        Extract random images from parquet files in the compressed directory.
-
-        Returns:
-            List of paths to extracted image files.
- """
- if n_items_per_source is None:
- n_items_per_source = self.num_images_per_source
-
- extracted_files = []
- parquet_paths = self._get_compressed_files()
- if not parquet_paths:
- bt.logging.warning(f"[{self.compressed_dir}] No parquet files found")
- return extracted_files
-
- for parquet_path in parquet_paths:
- dataset = Path(parquet_path).relative_to(self.compressed_dir).parts[0]
- try:
- extracted_files += extract_images_from_parquet(
- parquet_path,
- self.cache_dir / dataset,
- n_items_per_source
- )
- except Exception as e:
- bt.logging.error(f"Error processing parquet file {parquet_path}: {e}")
- return extracted_files
-
- def sample(self, remove_from_cache=False) -> Optional[Dict[str, Any]]:
- """
- Sample a random image and its metadata from the cache.
-
- Returns:
- Dictionary containing:
- - image: PIL Image
- - path: Path to source file
- - dataset: Source dataset name
-            - index: Source index of the image, if present in metadata
-            - mask_center: Center of the inpainting mask, if present in metadata
- Returns None if no valid image is available.
- """
- cached_files = self._get_cached_files(group_by_source=True)
- if not cached_files:
- bt.logging.warning(f"[{self.cache_dir}] No images available in cache")
- return None
-
- attempts = 0
- max_attempts = len(cached_files) * 2
-
- while attempts < max_attempts:
- attempts += 1
- source = random.choice(list(cached_files.keys()))
- image_path = random.choice(cached_files[source])
-
- try:
- image = Image.open(image_path)
- metadata = json.loads(image_path.with_suffix('.json').read_text())
- if remove_from_cache:
- try:
- os.remove(image_path)
- os.remove(image_path.with_suffix('.json'))
- except Exception as e:
- bt.logging.warning(f"[{self.cache_dir}] Failed to remove files for {image_path}: {e}")
- return {
- 'image': image,
- 'path': str(image_path),
-                    'dataset': metadata.get('dataset', str(Path(image_path).parent.name)),
- 'index': metadata.get('index', None),
- 'mask_center': metadata.get('mask_center', None)
- }
-
- except Exception as e:
- bt.logging.warning(f"Failed to load image {image_path}: {e}")
- continue
-
- bt.logging.warning(f"Failed to find valid image after {attempts} attempts")
- return None
diff --git a/bitmind/validator/cache/util.py b/bitmind/validator/cache/util.py
deleted file mode 100644
index d429db48..00000000
--- a/bitmind/validator/cache/util.py
+++ /dev/null
@@ -1,77 +0,0 @@
-from pathlib import Path
-from typing import Union, Callable
-from zipfile import ZipFile, BadZipFile
-from enum import Enum, auto
-import asyncio
-import pyarrow.parquet as pq
-import bittensor as bt
-
-
-def seconds_to_str(seconds):
- seconds = int(float(seconds))
- hours = seconds // 3600
- minutes = (seconds % 3600) // 60
- seconds = seconds % 60
- return f"{hours:02}:{minutes:02}:{seconds:02}"
-
-
-def get_most_recent_update_time(directory: Path) -> float:
- """Get the most recent modification time of any file in directory."""
- try:
- mtimes = [f.stat().st_mtime for f in directory.iterdir()]
- return max(mtimes) if mtimes else 0
- except Exception as e:
- bt.logging.error(f"Error getting modification times: {e}")
- return 0
-
-
-class FileType(Enum):
- PARQUET = auto()
- ZIP = auto()
-
-
-def get_integrity_check(file_type: FileType) -> Callable[[Path], bool]:
- """Returns the appropriate validation function for the file type."""
- if file_type == FileType.PARQUET:
- return is_parquet_complete
- elif file_type == FileType.ZIP:
- return is_zip_complete
- raise ValueError(f"Unsupported file type: {file_type}")
-
-
-def is_zip_complete(zip_path: Union[str, Path], testzip=False) -> bool:
- """
- Args:
- zip_path: Path to zip file
- testzip: More thorough, less efficient
- Returns:
- bool: True if zip is valid, False otherwise
- """
- try:
- with ZipFile(zip_path) as zf:
- if testzip:
- zf.testzip()
- else:
- zf.namelist()
- return True
- except (BadZipFile, Exception) as e:
- bt.logging.error(f"Zip file {zip_path} is invalid: {e}")
- return False
-
-
-def is_parquet_complete(path: Path) -> bool:
- """
- Args:
- path: Path to the parquet file
-
- Returns:
- bool: True if file is valid, False otherwise
- """
- try:
- with open(path, 'rb') as f:
- pq.read_metadata(f)
- return True
- except Exception as e:
- bt.logging.error(f"Parquet file {path} is incomplete or corrupted: {e}")
- return False
-
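A brief sketch of these helpers in use (the zip path is illustrative):

from pathlib import Path
from bitmind.validator.cache.util import FileType, get_integrity_check, seconds_to_str

check = get_integrity_check(FileType.ZIP)      # returns is_zip_complete
print(check(Path("/tmp/sources/sample.zip")))  # False (with a logged error) if missing or truncated
print(seconds_to_str(4 * 60 * 60))             # '04:00:00'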
diff --git a/bitmind/validator/cache/video_cache.py b/bitmind/validator/cache/video_cache.py
deleted file mode 100644
index 12089539..00000000
--- a/bitmind/validator/cache/video_cache.py
+++ /dev/null
@@ -1,221 +0,0 @@
-import os
-import random
-from io import BytesIO
-from pathlib import Path
-from typing import Dict, List, Optional, Union
-
-import bittensor as bt
-import ffmpeg
-from PIL import Image
-
-from .base_cache import BaseCache
-from .extract import extract_videos_from_zip
-from .util import is_zip_complete
-from bitmind.validator.video_utils import get_video_duration
-
-
-class VideoCache(BaseCache):
- """
- A class to manage video caching and processing operations.
-
- This class handles the caching, updating, and sampling of video files from
- compressed archives and optionally YouTube. It maintains both a compressed
- cache of source files and an extracted cache of video files ready for processing.
- """
-
- def __init__(
- self,
- cache_dir: Union[str, Path],
- datasets: Optional[dict] = None,
- video_update_interval: int = 1,
- zip_update_interval: int = 6,
- num_zips_per_dataset: int = 1,
- num_videos_per_zip: int = 10,
- max_compressed_size_gb: int = 100,
- max_extracted_size_gb: int = 10
- ) -> None:
- """
- Initialize the VideoCache.
-
- Args:
- cache_dir: Path to store extracted video files
- video_update_interval: Hours between video cache updates
- zip_update_interval: Hours between zip cache updates
-            num_zips_per_dataset: Number of zip archives to download per dataset
-            num_videos_per_zip: Number of videos to extract per zip archive
- """
- super().__init__(
- cache_dir=cache_dir,
- datasets=datasets,
- extracted_update_interval=video_update_interval,
- compressed_update_interval=zip_update_interval,
- num_sources_per_dataset=num_zips_per_dataset,
- file_extensions=['.mp4', '.avi', '.mov', '.mkv'],
- compressed_file_extension='.zip',
- max_compressed_size_gb=max_compressed_size_gb,
- max_extracted_size_gb=max_extracted_size_gb
- )
- self.num_videos_per_zip = num_videos_per_zip
-
- def _clear_incomplete_sources(self) -> None:
- """Remove any incomplete or corrupted zip files from cache."""
- for path in self._get_compressed_files():
- if path.suffix == '.zip' and not is_zip_complete(path):
- try:
- path.unlink()
- bt.logging.warning(f"Removed incomplete zip file {path}")
- except Exception as e:
- bt.logging.error(f"Error removing incomplete zip {path}: {e}")
-
- def _extract_random_items(self, n_items_per_source: Optional[int] = None) -> List[Path]:
- """
- Extract random videos from zip files in compressed directory.
-
- Returns:
- List of paths to extracted video files.
- """
- if n_items_per_source is None:
- n_items_per_source = self.num_videos_per_zip
-
- extracted_files = []
- zip_paths = self._get_compressed_files()
- if not zip_paths:
- bt.logging.warning(f"[{self.compressed_dir}] No zip files found")
- return extracted_files
-
- for zip_path in zip_paths:
- dataset = Path(zip_path).relative_to(self.compressed_dir).parts[0]
- try:
- extracted_files += extract_videos_from_zip(
- zip_path,
- self.cache_dir / dataset,
- n_items_per_source)
- except Exception as e:
- bt.logging.error(f"[{self.compressed_dir}] Error processing zip file {zip_path}: {e}")
-
- return extracted_files
-
- def sample(
- self,
- num_frames: int = 6,
- fps: Optional[float] = None,
- min_fps: Optional[float] = None,
- max_fps: Optional[float] = None,
- remove_from_cache: bool = False
- ) -> Optional[Dict[str, Union[List[Image.Image], str, float]]]:
- """
- Sample random frames from a random video in the cache.
-
- Args:
- num_frames: Number of consecutive frames to sample
- fps: Fixed frames per second to sample. Mutually exclusive with min_fps/max_fps.
- min_fps: Minimum frames per second when auto-calculating fps. Must be used with max_fps.
- max_fps: Maximum frames per second when auto-calculating fps. Must be used with min_fps.
-
- Returns:
- Dictionary containing:
-            - video: List of sampled video frames as PIL Images
-            - fps: Frame rate used for sampling
-            - num_frames: Number of frames requested
-            - path: Path to source video file
-            - dataset: Name of source dataset
-            - total_duration: Total video duration in seconds
-            - sampled_length: Number of seconds sampled
- Returns None if no videos are available or extraction fails.
- """
- if fps is not None and (min_fps is not None or max_fps is not None):
- raise ValueError("Cannot specify both fps and min_fps/max_fps")
- if (min_fps is None) != (max_fps is None):
- raise ValueError("min_fps and max_fps must be specified together")
-
- video_files = self._get_cached_files(group_by_source=True)
- if not video_files:
- bt.logging.warning("No videos available in cache")
- return None
-
- source = random.choice(list(video_files.keys()))
- video_path = random.choice(video_files[source])
- if not Path(video_path).exists():
- bt.logging.error(f"Selected video {video_path} not found")
- return None
-
- try:
- duration = get_video_duration(str(video_path))
- except Exception as e:
- bt.logging.error(f"Unable to extract video duration from {str(video_path)}")
- return None
-
- # Use fixed fps if provided, otherwise calculate from range
- frame_rate = fps
- if frame_rate is None:
- # For very short videos (< 1 second), use max_fps to capture detail
- if duration <= 1.0:
- frame_rate = max_fps
- else:
- # For longer videos, scale fps inversely with duration
- # This ensures we don't span too much of longer videos
- # while still capturing enough detail in shorter ones
- target_duration = min(2.0, duration * 0.2) # Cap at 2 seconds or 20% of duration
- frame_rate = (num_frames - 1) / target_duration
- frame_rate = max(min_fps, min(frame_rate, max_fps))
-
- sample_duration = (num_frames - 1) / frame_rate
- start_time = random.uniform(0, max(0, duration - sample_duration))
- frames: List[Image.Image] = []
-
- no_data = []
- for i in range(num_frames):
- timestamp = start_time + (i / frame_rate)
-
- try:
- # extract frames
- out_bytes, err = (
- ffmpeg
- .input(str(video_path), ss=str(timestamp))
- .filter('select', 'eq(n,0)')
- .output(
- 'pipe:',
- vframes=1,
- format='image2',
- vcodec='png',
- loglevel='error' # silence ffmpeg output
- )
- .run(capture_stdout=True, capture_stderr=True)
- )
-
- if not out_bytes:
- no_data.append(timestamp)
- continue
-
- try:
- frame = Image.open(BytesIO(out_bytes))
- frame.load() # Verify image can be loaded
- frames.append(frame)
- bt.logging.debug(f'Successfully extracted frame at {timestamp}s')
- except Exception as e:
- bt.logging.error(f'Failed to process frame at {timestamp}s: {e}')
- continue
-
- except ffmpeg.Error as e:
- bt.logging.error(f'FFmpeg error at {timestamp}s: {e.stderr.decode()}')
- continue
-
- if len(no_data) > 0:
- tmin, tmax = min(no_data), max(no_data)
- bt.logging.warning(f'No data received for {len(no_data)} frames between {tmin} and {tmax}')
-
- if remove_from_cache:
- try:
- os.remove(video_path)
- os.remove(video_path.with_suffix('.json'))
- except Exception as e:
- bt.logging.warning(f"Failed to remove files for {video_path}: {e}")
-
- bt.logging.success(f"Sampled {len(frames)} frames at {frame_rate}fps")
- return {
- 'video': frames,
- 'fps': frame_rate,
- 'num_frames': num_frames,
- 'path': str(video_path),
- 'dataset': str(Path(video_path).parent.name),
- 'total_duration': duration,
- 'sampled_length': sample_duration
- }
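A hedged sketch of sampling from the video cache, assuming the cache directory has already been populated by the background updaters (config names are taken from config.py below):

from bitmind.validator.cache import VideoCache
from bitmind.validator.config import REAL_VIDEO_CACHE_DIR, VIDEO_DATASETS

cache = VideoCache(cache_dir=REAL_VIDEO_CACHE_DIR, datasets=VIDEO_DATASETS['real'])

clip = cache.sample(num_frames=8, fps=8)                  # fixed sampling rate
clip = cache.sample(num_frames=8, min_fps=8, max_fps=30)  # rate derived from clip duration
if clip is not None:
    print(len(clip['video']), clip['fps'], clip['total_duration'])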
diff --git a/bitmind/validator/challenge.py b/bitmind/validator/challenge.py
deleted file mode 100644
index 9cca7724..00000000
--- a/bitmind/validator/challenge.py
+++ /dev/null
@@ -1,182 +0,0 @@
-import random
-from typing import Dict, List, Any, Tuple, Union
-from dataclasses import dataclass, field
-from PIL import Image
-
-import numpy as np
-import pandas as pd
-import wandb
-import bittensor as bt
-
-from bitmind.protocol import prepare_synapse
-from bitmind.utils.image_transforms import apply_augmentation_by_level
-from bitmind.utils.uids import get_random_uids
-from bitmind.validator.reward import get_rewards
-from bitmind.validator.config import (
- TARGET_IMAGE_SIZE,
- MIN_FRAMES,
- MAX_FRAMES,
- P_STITCH,
- LABELS,
- LABEL_TO_TYPE,
- LABEL_PROBS,
- Modality,
- MODALITY_PROBS,
-)
-
-
-@dataclass
-class ChallengeConfig:
- """Configuration parameters for challenge generation."""
- target_image_size: Tuple[int] = TARGET_IMAGE_SIZE
- modality_options: Tuple[int] = field(default_factory=lambda: [m.value for m in Modality])
- modality_probs: List[float] = MODALITY_PROBS
- label_options: Tuple[int] = LABELS
- label_probs: Tuple[float] = LABEL_PROBS
- label_to_type: Dict[int, str] = field(default_factory=lambda: LABEL_TO_TYPE)
- min_frames: int = MIN_FRAMES
- max_frames: int = MAX_FRAMES
- min_fps: int = 8
- max_fps: int = 30
- stitch_prob: float = P_STITCH
-
-
-@dataclass
-class Challenge:
- """
- Container for challenge data and metadata.
-
- A challenge consists of either an image or video that needs to be classified
- as real, synthetic, or semisynthetic. The class manages the challenge lifecycle
- including creation, data processing, and metadata handling.
-
- Attributes:
- label (int): Label value (0=real, 1=synthetic, 2=semisynthetic)
- media_type (str): Type of media ('real', 'synthetic', 'semisynthetic')
- modality (str): Media modality ('image' or 'video')
- original_media (Union[Image.Image, List[Image.Image], None]): The actual image or
- video frames
-        augmented_media (Union[Image.Image, List[Image.Image], None]): The image or
-            video frames with transformations and augmentations applied
- metadata (Dict[str, Any]): Additional information about the challenge
- config (ChallengeConfig): Configuration parameters for challenge generation
- """
- label: int = -1
- media_type: str = ""
- modality: str = ""
- original_media: Union[Image.Image, List[Image.Image], None] = None
- augmented_media: Union[Image.Image, List[Image.Image], None] = None
- metadata: Dict[str, Any] = field(default_factory=dict)
-
- config: ChallengeConfig = field(default_factory=ChallengeConfig)
-
- @classmethod
- def create(cls, media_cache):
- """Factory method to create and initialize a challenge."""
- challenge = cls()
-
- challenge.label = np.random.choice(
- challenge.config.label_options,
- p=challenge.config.label_probs
- )
- challenge.media_type = challenge.config.label_to_type[challenge.label]
- challenge.modality = np.random.choice(
- challenge.config.modality_options,
- p=challenge.config.modality_probs
- )
-
- bt.logging.info(f"Sampling data from {challenge.modality} cache")
- cache = media_cache[challenge.modality][challenge.media_type]
-
- if challenge.modality == 'video':
- sample = challenge.sample_video_frames(cache)
- elif challenge.modality == 'image':
- sample = cache.sample()
-
- if sample is None:
- bt.logging.warning(f"Waiting for {challenge.media_type} cache to populate. Challenge skipped.")
- return None
-
- challenge.original_media = sample[challenge.modality]
- try:
- challenge.augmented_media, aug_level, aug_params = apply_augmentation_by_level(
- challenge.original_media,
- challenge.config.target_image_size,
- sample.get('mask_center', None))
- except Exception as e:
- bt.logging.error(f"Unable to apply augmentations: {e}\nChallenge generation failed.")
- return None
-
- sample.update({'aug_params': aug_params, 'aug_level': aug_level})
- if not challenge.process_metadata(sample):
- bt.logging.warning(f"Failed to process metadata. Challenge skipped.")
- return None
-
- return challenge
-
- def sample_video_frames(self, video_cache):
- """Sample frames from the video cache, either as a single clip or two combined clips."""
- min_frames = self.config.min_frames
- max_frames = self.config.max_frames
- min_fps = self.config.min_fps
- max_fps = self.config.max_fps
-
- if np.random.rand() > self.config.stitch_prob:
- num_frames = random.randint(min_frames, max_frames)
- sample = video_cache.sample(num_frames, min_fps=min_fps, max_fps=max_fps)
- else:
- num_frames_A = random.randint(min_frames, max_frames - 1)
- sample_A = video_cache.sample(num_frames_A, min_fps=min_fps, max_fps=max_fps)
- if sample_A is None:
- return None
- num_frames_B = random.randint(min_frames, max(max_frames - num_frames_A, min_frames + 1))
-            sample_B = video_cache.sample(num_frames_B, fps=sample_A['fps'])
-            if sample_B is None:
-                return None
- sample = {k + '_A': v for k, v in sample_A.items()}
- sample.update({k + '_B': v for k, v in sample_B.items()})
- sample['video'] = sample_A['video'] + sample_B['video']
-
- return sample
-
- def process_metadata(self, sample) -> bool:
- """Prepare challenge metadata and media for logging to Weights & Biases """
- self.metadata = {
- 'label': int(self.label),
- 'media_type': str(self.media_type),
- 'modality': str(self.modality)
- }
- self.metadata.update({
- k: v for k, v in sample.items()
- if self.modality not in k
- })
- try:
- if self.modality == 'video':
- def create_wandb_video(video_frames, fps):
- frames = [np.array(img) for img in video_frames]
- frames_arr = np.stack(frames, axis=0)
- if frames_arr.min() >= 0 and frames_arr.max() <= 1:
- frames_arr = (frames_arr * 255).astype(np.uint8)
-
- if frames_arr.shape[1] != 3:
- frames_arr = frames_arr.transpose(0, 3, 1, 2)
-
- return wandb.Video(frames_arr, format="mp4", fps=fps)
-
- if 'video_A' in sample:
- self.metadata['video_A'] = create_wandb_video(sample['video_A'], sample['fps_A'])
- self.metadata['video_B'] = create_wandb_video(sample['video_B'], sample['fps_B'])
- else:
- self.metadata['video'] = create_wandb_video(self.original_media, self.metadata.get('fps', 30))
-
- self.metadata['augmented_video'] = create_wandb_video(
- self.augmented_media, self.metadata.get('fps', 30))
-
- elif self.modality == 'image':
- self.metadata['image'] = wandb.Image(self.original_media)
- self.metadata['augmented_image'] = wandb.Image(self.augmented_media)
-
- return True
-
- except Exception as e:
- bt.logging.error(e)
- bt.logging.error(f"{self.modality} is truncated or corrupt. Challenge skipped.")
- return False
\ No newline at end of file
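For orientation, a sketch of the challenge lifecycle, assuming `media_cache` is the validator's nested cache dict keyed by modality and media type:

from bitmind.validator.challenge import Challenge

# media_cache = {'image': {'real': ImageCache, ...}, 'video': {'real': VideoCache, ...}}
challenge = Challenge.create(media_cache)
if challenge is not None:
    print(challenge.modality, challenge.media_type, challenge.label)
    # challenge.augmented_media is sent to miners; challenge.metadata is logged to W&B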
diff --git a/bitmind/validator/config.py b/bitmind/validator/config.py
deleted file mode 100644
index 08d04842..00000000
--- a/bitmind/validator/config.py
+++ /dev/null
@@ -1,500 +0,0 @@
-from strenum import StrEnum
-from pathlib import Path
-from typing import Dict, List, Union, Optional, Any
-
-import numpy as np
-import torch
-from diffusers import (
- StableDiffusionPipeline,
- StableDiffusionXLPipeline,
- FluxPipeline,
- CogVideoXPipeline,
- MochiPipeline,
- HunyuanVideoPipeline,
- AnimateDiffPipeline,
- IFPipeline,
- IFSuperResolutionPipeline,
- EulerDiscreteScheduler,
- DEISMultistepScheduler,
- AutoPipelineForInpainting,
- StableDiffusionInpaintPipeline,
- CogView4Pipeline,
- CogVideoXImageToVideoPipeline
-)
-
-from .model_utils import (
- load_annimatediff_motion_adapter,
- load_hunyuanvideo_transformer,
- JanusWrapper
-)
-
-
-TARGET_IMAGE_SIZE: tuple[int, int] = (256, 256)
-
-MAINNET_UID = 34
-TESTNET_UID = 168
-
-# Project constants
-MAINNET_WANDB_PROJECT: str = 'bitmind-subnet'
-TESTNET_WANDB_PROJECT: str = 'bitmind'
-WANDB_ENTITY: str = 'bitmindai'
-
-
-# Enums
-class MediaType(StrEnum):
- REAL = "real"
- SYNTHETIC = "synthetic"
- SEMISYNTHETIC = "semisynthetic"
-
-
-class Modality(StrEnum):
- IMAGE = "image"
- VIDEO = "video"
-
-
-# Cache directories
-HUGGINGFACE_CACHE_DIR: Path = Path.home() / '.cache' / 'huggingface'
-SN34_CACHE_DIR: Path = Path.home() / '.cache' / 'sn34'
-SN34_CACHE_DIR.mkdir(parents=True, exist_ok=True)
-
-VALIDATOR_INFO_PATH: Path = SN34_CACHE_DIR / 'validator.yaml'
-
-IMAGE_CACHE_DIR: Path = SN34_CACHE_DIR / Modality.IMAGE
-VIDEO_CACHE_DIR: Path = SN34_CACHE_DIR / Modality.VIDEO
-
-REAL_IMAGE_CACHE_DIR: Path = IMAGE_CACHE_DIR / MediaType.REAL
-SYNTH_IMAGE_CACHE_DIR: Path = IMAGE_CACHE_DIR / MediaType.SYNTHETIC
-SEMISYNTH_IMAGE_CACHE_DIR: Path = IMAGE_CACHE_DIR / MediaType.SEMISYNTHETIC
-
-REAL_VIDEO_CACHE_DIR: Path = VIDEO_CACHE_DIR / MediaType.REAL
-SYNTH_VIDEO_CACHE_DIR: Path = VIDEO_CACHE_DIR / MediaType.SYNTHETIC
-SEMISYNTH_VIDEO_CACHE_DIR: Path = VIDEO_CACHE_DIR / MediaType.SEMISYNTHETIC
-
-LABELS = (0, 1, 2)
-LABEL_TO_TYPE = {
- 0: MediaType.REAL,
- 1: MediaType.SYNTHETIC,
- 2: MediaType.SEMISYNTHETIC
-}
-
-P_REAL: float = 0.5
-P_SYNTH: float = 0.4
-P_SEMISYNTH: float = 0.1
-LABEL_PROBS: List[float] = (P_REAL, P_SYNTH, P_SEMISYNTH)
-
-MODALITY_PROBS = (0.5, 0.5)
-
-# Probability of concatenating together two videos
-# Will only ever combine videos of the same type
-# i.e. real + real, synth + synth, semisynth + semisynth
-P_STITCH: float = 0.2
-
-# Number of frames in challenge
-MIN_FRAMES = 8
-MAX_FRAMES = 129
-
-# Update intervals in hours
-VIDEO_ZIP_CACHE_UPDATE_INTERVAL = 2
-IMAGE_PARQUET_CACHE_UPDATE_INTERVAL = 2
-VIDEO_CACHE_UPDATE_INTERVAL = 1
-IMAGE_CACHE_UPDATE_INTERVAL = 1
-
-MAX_COMPRESSED_GB = 50
-MAX_EXTRACTED_GB = 5
-
-
-# dataset configurations
-IMAGE_DATASETS = {
- "real": [
- {"path": "bitmind/bm-eidon-image"},
- {"path": "bitmind/bm-real"},
- {"path": "bitmind/open-image-v7-256"},
- {"path": "bitmind/celeb-a-hq"},
- {"path": "bitmind/ffhq-256"},
- {"path": "bitmind/MS-COCO-unique-256"},
- {"path": "bitmind/AFHQ"},
- {"path": "bitmind/lfw"},
- {"path": "bitmind/caltech-256"},
- {"path": "bitmind/caltech-101"},
- {"path": "bitmind/dtd"},
- {"path": "bitmind/idoc-mugshots-images"}
- ],
- "semisynthetic": [
- {"path": "bitmind/face-swap"}
- ],
- "synthetic": [
- {"path": "bitmind/JourneyDB"},
- {"path": "bitmind/GenImage_MidJourney"}
- ]
-}
-
-VIDEO_DATASETS = {
- "real": [
- {"path": "bitmind/bm-eidon-video", "filetype": "zip"},
- {"path": "shangxd/imagenet-vidvrd", "filetype": "zip"},
- {"path": "nkp37/OpenVid-1M", "filetype": "zip"}
- ],
- "semisynthetic": [
- {"path": "bitmind/semisynthetic-video", "filetype": "zip"}
- ]
-}
-
-
-# Prompt generation model configurations
-IMAGE_ANNOTATION_MODEL: str = "Salesforce/blip2-opt-6.7b-coco"
-TEXT_MODERATION_MODEL: str = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
-
-# Text-to-image model configurations
-T2I_MODELS: Dict[str, Dict[str, Any]] = {
- "THUDM/CogView4-6B": {
- "pipeline_cls": CogView4Pipeline,
- "from_pretrained_args": {
- "torch_dtype": torch.bfloat16,
- "use_safetensors": True
- },
- "generate_args": {
- "guidance_scale": 3.5,
- "num_images_per_prompt": 1,
- "num_inference_steps": 50,
- "width": 512,
- "height": 512
- },
- "use_autocast": False
- },
- "stabilityai/stable-diffusion-xl-base-1.0": {
- "pipeline_cls": StableDiffusionXLPipeline,
- "from_pretrained_args": {
- "use_safetensors": True,
- "torch_dtype": torch.float16,
- "variant": "fp16"
- },
- "use_autocast": False
- },
- "SG161222/RealVisXL_V4.0": {
- "pipeline_cls": StableDiffusionXLPipeline,
- "from_pretrained_args": {
- "use_safetensors": True,
- "torch_dtype": torch.float16,
- "variant": "fp16"
- }
- },
- "Corcelio/mobius": {
- "pipeline_cls": StableDiffusionXLPipeline,
- "from_pretrained_args": {
- "use_safetensors": True,
- "torch_dtype": torch.float16
- }
- },
- "black-forest-labs/FLUX.1-dev": {
- "pipeline_cls": FluxPipeline,
- "from_pretrained_args": {
- "use_safetensors": True,
- "torch_dtype": torch.bfloat16,
- },
- "generate_args": {
- "guidance_scale": 2,
- "num_inference_steps": {"min": 50, "max": 125},
- "generator": torch.Generator("cuda" if torch.cuda.is_available() else "cpu"),
- "resolution": [512, 768]
- },
- "enable_model_cpu_offload": False
- },
- "runwayml/stable-diffusion-v1-5-midjourney-v6": {
- "pipeline_cls": StableDiffusionPipeline,
- "from_pretrained_args": {
- "model_id": "runwayml/stable-diffusion-v1-5",
- "use_safetensors": True,
- "torch_dtype": torch.float16,
- },
- "lora_model_id": "Kvikontent/midjourney-v6",
- "lora_loading_args": {
- "use_peft_backend": True
- },
- "enable_model_cpu_offload": False
- },
- "prompthero/openjourney-v4" : {
- "pipeline_cls": StableDiffusionPipeline,
- "from_pretrained_args": {
- "use_safetensors": True,
- "torch_dtype": torch.float16,
- }
- },
- "cagliostrolab/animagine-xl-3.1": {
- "pipeline_cls": StableDiffusionXLPipeline,
- "from_pretrained_args": {
- "use_safetensors": True,
- "torch_dtype": torch.float16,
- },
- },
- "DeepFloyd/IF": {
- "pipeline_cls": {
- "stage1": IFPipeline,
- "stage2": IFSuperResolutionPipeline
- },
- "from_pretrained_args": {
- "stage1": {
- "base": "DeepFloyd/IF-I-XL-v1.0",
- "torch_dtype": torch.float16,
- "variant": "fp16",
- "clean_caption": False,
- "watermarker": None,
- "requires_safety_checker": False
- },
- "stage2": {
- "base": "DeepFloyd/IF-II-L-v1.0",
- "torch_dtype": torch.float16,
- "variant": "fp16",
- "text_encoder": None,
- "watermarker": None,
- "requires_safety_checker": False
- }
- },
- "pipeline_stages": [
- {
- "name": "stage1",
- "args": {
- "output_type": "pt",
- "num_images_per_prompt": 1,
- "return_dict": True
- },
- "output_attr": "images",
- "output_transform": lambda x: x[0].unsqueeze(0),
- "save_prompt_embeds": True
- },
- {
- "name": "stage2",
- "input_key": "image",
- "args": {
- "output_type": "pil",
- "num_images_per_prompt": 1
- },
- "output_attr": "images",
- "use_prompt_embeds": True
- }
- ],
- "clear_memory_on_stage_end": True
- },
- "deepseek-ai/Janus-Pro-7B": {
- "pipeline_cls": JanusWrapper,
- "from_pretrained_args": {
- "torch_dtype": torch.bfloat16,
- "use_safetensors": True,
- },
- "generate_args": {
- "temperature": 1.0,
- "parallel_size": 4,
- "cfg_weight": 5.0,
- "image_token_num_per_image": 576,
- "img_size": 384,
- "patch_size": 16
- },
- "use_autocast": False,
- "enable_model_cpu_offload": False
- },
-}
-T2I_MODEL_NAMES: List[str] = list(T2I_MODELS.keys())
-
-# Image-to-image model configurations
-I2I_MODELS: Dict[str, Dict[str, Any]] = {
- "diffusers/stable-diffusion-xl-1.0-inpainting-0.1": {
- "pipeline_cls": AutoPipelineForInpainting,
- "from_pretrained_args": {
- "use_safetensors": True,
- "torch_dtype": torch.float16,
- "variant": "fp16"
- },
- "generate_args": {
- "guidance_scale": 7.5,
- "num_inference_steps": 50,
- "strength": 0.99,
- "generator": torch.Generator("cuda" if torch.cuda.is_available() else "cpu"),
- }
- },
- "Lykon/dreamshaper-8-inpainting": {
- "pipeline_cls": AutoPipelineForInpainting,
- "from_pretrained_args": {
- "torch_dtype": torch.float16,
- "variant": "fp16"
- },
- "generate_args": {
- "num_inference_steps": {"min": 40, "max": 60},
- },
- "scheduler": {
- "cls": DEISMultistepScheduler
- }
- }
-}
-I2I_MODEL_NAMES: List[str] = list(I2I_MODELS.keys())
-
-# Text-to-video model configurations
-T2V_MODELS: Dict[str, Dict[str, Any]] = {
- "tencent/HunyuanVideo": {
- "pipeline_cls": HunyuanVideoPipeline,
- "from_pretrained_args": {
- "model_id": "tencent/HunyuanVideo",
- "transformer": ( # custom functions supplied as tuple of (fn, args)
- load_hunyuanvideo_transformer,
- {
- "model_id": "tencent/HunyuanVideo",
- "subfolder": "transformer",
- "torch_dtype": torch.bfloat16,
- "revision": 'refs/pr/18'
- }
- ),
- "revision": 'refs/pr/18',
- "torch_dtype": torch.bfloat16
- },
- "generate_args": {
- "num_frames": {"min": 61, "max": 129},
- "resolution": {"options": [
- [720, 1280], [1280, 720], [1104, 832], [832,1104], [960,960],
- [544, 960], [960, 544], [624, 832], [832, 624], [720, 720]
- ]},
- "num_inference_steps": {"min": 30, "max": 50},
- },
- "save_args": {"fps": 30},
- "use_autocast": False,
- "vae_enable_tiling": True
- },
- "genmo/mochi-1-preview": {
- "pipeline_cls": MochiPipeline,
- "from_pretrained_args": {
- "variant": "bf16",
- "torch_dtype": torch.bfloat16
- },
- "generate_args": {
- "num_frames": 84,
- "num_inference_steps": {"min": 30, "max": 65},
- "resolution": [480, 848]
- },
- "save_args": {"fps": 30},
- "vae_enable_tiling": True
- },
- 'THUDM/CogVideoX-5b': {
- "pipeline_cls": CogVideoXPipeline,
- "from_pretrained_args": {
- "use_safetensors": True,
- "torch_dtype": torch.bfloat16
- },
- "generate_args": {
- "guidance_scale": 2,
- "num_videos_per_prompt": 1,
- "num_inference_steps": {"min": 50, "max": 125},
- "num_frames": 48
- },
- "save_args": {"fps": 8},
- "enable_model_cpu_offload": True,
- #"enable_sequential_cpu_offload": True,
- "vae_enable_slicing": True,
- "vae_enable_tiling": True
- },
- 'ByteDance/AnimateDiff-Lightning': {
- "pipeline_cls": AnimateDiffPipeline,
- "from_pretrained_args": {
- "model_id": "emilianJR/epiCRealism",
- "torch_dtype": torch.bfloat16,
- "motion_adapter": (
- load_annimatediff_motion_adapter,
- {"step": 4}
- )
- },
- "generate_args": {
- "guidance_scale": 2,
- "num_inference_steps": {"min": 50, "max": 125},
- "resolution": {"options": [
- [512, 512], [512, 768], [512, 1024],
- [768, 512], [768, 768], [768, 1024],
- [1024, 512], [1024, 768], [1024, 1024]
- ]}
- },
- "save_args": {"fps": 15},
- "scheduler": {
- "cls": EulerDiscreteScheduler,
- "from_config_args": {
- "timestep_spacing": "trailing",
- "beta_schedule": "linear"
- }
- }
- }
-}
-T2V_MODEL_NAMES: List[str] = list(T2V_MODELS.keys())
-
-# Image-to-video model configurations
-I2V_MODELS: Dict[str, Dict[str, Any]] = {
- "THUDM/CogVideoX1.5-5B-I2V": {
- "pipeline_cls": CogVideoXImageToVideoPipeline,
- "from_pretrained_args": {
- "use_safetensors": True,
- "torch_dtype": torch.bfloat16
- },
- "generate_args": {
- "guidance_scale": 2,
- "num_videos_per_prompt": 1,
- "num_inference_steps": {"min": 50, "max": 125},
- "num_frames": 49,
- "height": 768,
- "width": 768,
- },
- "save_args": {"fps": 8},
- "enable_model_cpu_offload": True,
- "vae_enable_slicing": True,
- "vae_enable_tiling": True
- }
-}
-I2V_MODEL_NAMES: List[str] = list(I2V_MODELS.keys())
-
-# Combined model configurations
-MODELS: Dict[str, Dict[str, Any]] = {**T2I_MODELS, **I2I_MODELS, **T2V_MODELS, **I2V_MODELS}
-MODEL_NAMES: List[str] = list(MODELS.keys())
-
-def get_modality(model_name):
- if model_name in T2V_MODEL_NAMES + I2V_MODEL_NAMES:
- return Modality.VIDEO
- elif model_name in T2I_MODEL_NAMES + I2I_MODEL_NAMES:
- return Modality.IMAGE
-
-def get_output_media_type(model_name):
- if model_name in I2I_MODEL_NAMES:
- return MediaType.SEMISYNTHETIC
- elif model_name in T2I_MODEL_NAMES + T2V_MODEL_NAMES + I2V_MODEL_NAMES:
- return MediaType.SYNTHETIC
-
-def get_task(model_name):
- if model_name in T2V_MODEL_NAMES:
- return 't2v'
- elif model_name in T2I_MODEL_NAMES:
- return 't2i'
- elif model_name in I2I_MODEL_NAMES:
- return 'i2i'
- elif model_name in I2V_MODEL_NAMES:
- return 'i2v'
-
-
-def select_random_model(task: Optional[str] = None) -> str:
- """
- Select a random text-to-image, text-to-video, image-to-image, or image-to-video model for the specified
- task.
-
- Args:
- task: The type of model to select ('t2i', 'i2i', 't2v', 'i2v', or 'random').
- If None or 'random', a task is chosen uniformly at random from the valid options.
-
- Returns:
- The name of the selected model.
-
- Raises:
- NotImplementedError: If the specified task is not supported.
- """
- if task is None or task == 'random':
- task = np.random.choice(['t2i', 'i2i', 't2v', 'i2v'])
-
- if task == 't2i':
- return np.random.choice(T2I_MODEL_NAMES)
- elif task == 't2v':
- return np.random.choice(T2V_MODEL_NAMES)
- elif task == 'i2i':
- return np.random.choice(I2I_MODEL_NAMES)
- elif task == 'i2v':
- return np.random.choice(I2V_MODEL_NAMES)
- else:
- raise NotImplementedError(f"Unsupported task: {task}")
\ No newline at end of file
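The removed config encodes sampling behavior implicitly: LABEL_PROBS weights real/synthetic/semisynthetic challenges, and several generate_args entries use {"min": ..., "max": ...} or {"options": [...]} dicts as ranges to be resolved at generation time. The sketch below shows one way those pieces could be consumed; the helper names and the interpretation of the ranged dicts are assumptions, not part of the deleted module.

```python
import numpy as np

LABEL_PROBS = (0.5, 0.4, 0.1)  # real, synthetic, semisynthetic
LABEL_TO_TYPE = {0: "real", 1: "synthetic", 2: "semisynthetic"}

def sample_label(rng=np.random):
    """Draw a challenge label according to LABEL_PROBS."""
    label = int(rng.choice(len(LABEL_PROBS), p=LABEL_PROBS))
    return label, LABEL_TO_TYPE[label]

def resolve_generate_args(generate_args, rng=np.random):
    """Resolve ranged args like {"min": 50, "max": 125} or {"options": [...]}
    into concrete values (assumed interpretation of the config schema)."""
    resolved = {}
    for key, value in generate_args.items():
        if isinstance(value, dict) and {"min", "max"} <= set(value):
            resolved[key] = int(rng.randint(value["min"], value["max"] + 1))
        elif isinstance(value, dict) and "options" in value:
            resolved[key] = value["options"][rng.randint(len(value["options"]))]
        else:
            resolved[key] = value
    return resolved

print(sample_label())
print(resolve_generate_args({"num_inference_steps": {"min": 50, "max": 125}, "guidance_scale": 2}))
```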
diff --git a/bitmind/validator/forward.py b/bitmind/validator/forward.py
deleted file mode 100644
index 8e29fcb2..00000000
--- a/bitmind/validator/forward.py
+++ /dev/null
@@ -1,123 +0,0 @@
-# The MIT License (MIT)
-# Copyright © 2023 Yuma Rao
-# developer: dubm
-# Copyright © 2023 BitMind
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
-
-# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-
-import time
-import wandb
-import bittensor as bt
-
-from bitmind.protocol import prepare_synapse
-from bitmind.utils.uids import get_random_uids
-from bitmind.validator.reward import get_rewards
-from bitmind.validator.config import MAINNET_UID
-from bitmind.validator.challenge import Challenge
-
-
-async def forward(self):
- """
- The forward function is called by the validator every time step.
- It is responsible for querying the network and scoring the responses.
-
- This implementation uses a Challenge class to encapsulate challenge data and configuration,
- with execution logic directly in the forward function.
- """
-
- # create challenge
- challenge = Challenge.create(self.media_cache)
- if challenge is None:
- return
-
- # sample miners
- miner_uids = get_random_uids(self, k=self.config.neuron.sample_size)
- axons = [self.metagraph.axons[uid] for uid in miner_uids]
- challenge.metadata['miner_uids'] = list(miner_uids)
- challenge.metadata['miner_hotkeys'] = list([axon.hotkey for axon in axons])
-
- # prepare synapse
- synapse = prepare_synapse(
- challenge.augmented_media,
- modality=challenge.modality)
-
- # on testnet, add label for eyeballing correctness
- if self.metagraph.netuid != MAINNET_UID:
- synapse.testnet_label = challenge.label
-
- # query miners
- bt.logging.info(f"Sending {challenge.modality} challenge to {len(miner_uids)} miners")
- start = time.time()
- responses = await self.dendrite(
- axons=axons,
- synapse=synapse,
- deserialize=True,
- run_async=True,
- timeout=9
- )
- bt.logging.info(f"Responses received in {time.time() - start}s")
- bt.logging.success(f"{challenge.media_type} {challenge.modality} challenge complete!")
- bt.logging.info({
- k: v for k, v in challenge.metadata.items()
- if k not in ('miner_uids', 'miner_hotkeys')
- })
-
- # compute miner rewards and update score vector
- bt.logging.info(f"Scoring responses")
- rewards, metrics = get_rewards(
- label=challenge.label,
- responses=responses,
- uids=miner_uids,
- axons=axons,
- challenge_modality=challenge.modality,
- performance_trackers=self.performance_trackers)
-
- self.update_scores(rewards, miner_uids)
-
- # log results, track responding miners for serving organics
- responding_miner_uids = []
- unresponsive_miner_uids = []
- for uid, pred, reward, perf in zip(miner_uids, responses, rewards, metrics):
- if -1 in pred:
- unresponsive_miner_uids.append(uid)
- continue
- metric_str = ' | '.join([f"{modality} {m}: {perf[modality][m]:.4f}" for modality in perf for m in perf[modality]])
- bt.logging.success(f"UID: {uid} | {pred} | Reward: {reward:.4f} | " + metric_str)
- responding_miner_uids.append(uid)
-
- if len(unresponsive_miner_uids) > 0:
- bt.logging.warning(f"Failed to get responses from {len(unresponsive_miner_uids)} miners:")
- for uid in unresponsive_miner_uids:
- bt.logging.warning(f'UID {uid} ({self.metagraph.axons[uid]})')
-
- if responding_miner_uids:
- self.last_responding_miner_uids = responding_miner_uids
-
- # add predictions, rewards, scores and metrics to logging data
- challenge.metadata['predictions'] = responses
- challenge.metadata['rewards'] = rewards
- challenge.metadata['scores'] = list(self.scores)
- for modality in ['image', 'video']:
- if metrics and modality in metrics[0]:
- for metric_name in list(metrics[0][modality].keys()):
- challenge.metadata[f'miner_{modality}_{metric_name}'] = [
- m[modality][metric_name] for m in metrics
- ]
-
- if not self.config.wandb.off:
- wandb.log(challenge.metadata)
-
- self.save_miner_history()
- self.media_cache[challenge.modality][challenge.media_type].prune_cache('extracted')
diff --git a/bitmind/validator/miner_performance_tracker.py b/bitmind/validator/miner_performance_tracker.py
deleted file mode 100644
index 22e4b19c..00000000
--- a/bitmind/validator/miner_performance_tracker.py
+++ /dev/null
@@ -1,92 +0,0 @@
-from sklearn.metrics import matthews_corrcoef
-from typing import Dict, List
-from collections import deque
-import bittensor as bt
-import numpy as np
-
-
-class MinerPerformanceTracker:
- """
- Tracks all recent miner performance to facilitate reward computation.
- """
- VERSION = 2
-
- def __init__(self, store_last_n_predictions: int = 100):
- self.prediction_history: Dict[int, deque] = {}
- self.label_history: Dict[int, deque] = {}
- self.miner_hotkeys: Dict[int, str] = {}
- self.store_last_n_predictions = store_last_n_predictions
- self.version = self.VERSION
-
- def reset_miner_history(self, uid: int, miner_hotkey: str):
- """
- Reset the history for a miner.
- """
- self.prediction_history[uid] = deque(maxlen=self.store_last_n_predictions)
- self.label_history[uid] = deque(maxlen=self.store_last_n_predictions)
- self.miner_hotkeys[uid] = miner_hotkey
-
- def update(self, uid: int, prediction: np.ndarray, label: int, miner_hotkey: str):
- """
- Update the miner prediction history.
- Args:
- - uid: UID of the miner being updated
- - prediction: numpy array of shape (3,) containing probabilities for [real, synthetic, semi-synthetic]
- - label: integer label (0 for real, 1 for synthetic, 2 for semi-synthetic)
- - miner_hotkey: hotkey of the responding miner, used to detect key changes
- """
- if uid not in self.prediction_history or self.miner_hotkeys.get(uid) != miner_hotkey:
- self.reset_miner_history(uid, miner_hotkey)
- self.prediction_history[uid].append(np.array(prediction)) # store full probability vector
- self.label_history[uid].append(label)
-
- def get_metrics(self, uid: int, window: int = None):
- """
- Get the performance metrics for a miner based on their last n predictions
- """
- if uid not in self.prediction_history:
- return self._empty_metrics()
-
- recent_preds = list(self.prediction_history[uid])
- recent_labels = list(self.label_history[uid])
- if window is not None:
- window = min(window, len(recent_preds))
- recent_preds = recent_preds[-window:]
- recent_labels = recent_labels[-window:]
-
- pred_probs = np.array([p for p in recent_preds if not np.array_equal(p, -1)])
- labels = np.array([l for i, l in enumerate(recent_labels) if not np.array_equal(recent_preds[i], -1)])
-
- if len(labels) == 0 or len(pred_probs) == 0:
- return self._empty_metrics()
-
- try:
- predictions = np.argmax(pred_probs, axis=1)
- # multiclass MCC (real vs synthetic vs semi-synthetic)
- multi_class_mcc = matthews_corrcoef(labels, predictions)
- # binary MCC (real vs any synthetic)
- binary_labels = (labels > 0).astype(int)
- binary_preds = (predictions > 0).astype(int)
- binary_mcc = matthews_corrcoef(binary_labels, binary_preds)
- return {
- 'multi_class_mcc': multi_class_mcc,
- 'binary_mcc': binary_mcc
- }
- except Exception as e:
- bt.logging.warning(f'Error in reward computation: {e}')
- return self._empty_metrics()
-
- def _empty_metrics(self):
- """
- Return a dictionary of empty metrics
- """
- return {
- 'multi_class_mcc': 0,
- 'binary_mcc': 0
- }
-
- def get_prediction_count(self, uid: int) -> int:
- """
- Get the number of predictions made by a specific miner.
- """
- if uid not in self.prediction_history:
- return 0
- return len(self.prediction_history[uid])
\ No newline at end of file
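A minimal usage sketch of the tracker being removed here, with the class above in scope and sklearn/numpy installed; the uid and hotkey values are placeholders.

```python
import numpy as np

tracker = MinerPerformanceTracker(store_last_n_predictions=100)

# Simulate a miner (uid=7) answering challenges with probability vectors
# over [real, synthetic, semi-synthetic].
history = [
    (np.array([0.8, 0.1, 0.1]), 0),   # correctly calls a real sample
    (np.array([0.1, 0.7, 0.2]), 1),   # correctly calls a synthetic sample
    (np.array([0.2, 0.3, 0.5]), 1),   # mistakes synthetic for semi-synthetic
]
for probs, label in history:
    tracker.update(uid=7, prediction=probs, label=label, miner_hotkey="5F...placeholder")

metrics = tracker.get_metrics(uid=7, window=100)
print(metrics)  # {'multi_class_mcc': ..., 'binary_mcc': ...}
```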
diff --git a/bitmind/validator/proxy.py b/bitmind/validator/proxy.py
deleted file mode 100644
index 966ff1b4..00000000
--- a/bitmind/validator/proxy.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import json
-import os
-from datetime import date
-
-
-class ProxyCounter:
- def __init__(self, save_path):
- self.save_path = save_path
- if os.path.exists(save_path):
- try:
- self.proxy_logs = json.load(open(save_path))
- except Exception as e:
- print(f"Error loading proxy logs: {e}")
- self.proxy_logs = {}
- else:
- self.proxy_logs = {}
-
- def update(self, is_success):
- today = str(date.today())
- self.proxy_logs.setdefault(today, {"success": 0, "fail": 0})
- if is_success:
- self.proxy_logs[today]["success"] += 1
- else:
- self.proxy_logs[today]["fail"] += 1
-
- def save(self):
- json.dump(self.proxy_logs, open(self.save_path, "w"))
diff --git a/bitmind/validator/reward.py b/bitmind/validator/reward.py
deleted file mode 100644
index 630b7f7a..00000000
--- a/bitmind/validator/reward.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# The MIT License (MIT)
-# Copyright © 2023 Yuma Rao
-# developer: dubm
-# Copyright © 2023 BitMind
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the "Software"), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
-
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-
-from typing import List, Dict, Tuple, Any
-import bittensor as bt
-import numpy as np
-
-
-def compute_penalty_multiplier(y_pred: np.ndarray) -> float:
- """
- Compute penalty for predictions outside valid range.
-
- Args:
- y_pred (np.ndarray): Predicted probabilities for each class, shape (3,)
-
- Returns:
- float: 0.0 if prediction is invalid, 1.0 if valid
- """
- sum_check = np.abs(np.sum(y_pred) - 1.0) < 1e-6
- range_check = np.all((y_pred >= 0.0) & (y_pred <= 1.0))
- return 1.0 if (sum_check and range_check) else 0.0
-
-
-def transform_reward(reward, pole=1.01):
- if reward == 0:
- return 0
- return 1 / (pole - np.array(reward))
-
-
-def get_rewards(
- label: int,
- responses: List[np.ndarray],
- uids: List[int],
- axons: List[bt.axon],
- challenge_modality: str,
- performance_trackers: Dict[str, Any]
-) -> Tuple[np.ndarray, List[Dict[str, Dict[str, float]]]]:
- """
- Calculate rewards for miner responses based on performance metrics.
-
- Args:
- label: The true label (0 for real, 1 for synthetic, 2 for semi-synthetic)
- responses: List of probability vectors from miners, each shape (3,)
- uids: List of miner UIDs
- axons: List of miner axons
- challenge_modality: Type of challenge ('video' or 'image')
- performance_trackers: Dict mapping modality to performance tracker
-
- Returns:
- Tuple containing:
- - np.ndarray: Array of rewards for each miner
- - List[Dict]: List of performance metrics for each miner
- """
- miner_rewards = []
- miner_metrics = []
-
- for axon, uid, pred_probs in zip(axons, uids, responses):
- miner_modality_rewards = {}
- miner_modality_metrics = {}
-
- for modality in ['image', 'video']:
- tracker = performance_trackers[modality]
- try:
- miner_hotkey = axon.hotkey
-
- if uid in tracker.miner_hotkeys and tracker.miner_hotkeys[uid] != miner_hotkey:
- bt.logging.info(f"Miner hotkey changed for UID {uid}. Resetting performance metrics.")
- tracker.reset_miner_history(uid, miner_hotkey)
-
- if modality == challenge_modality:
- tracker.update(uid, pred_probs, label, miner_hotkey)
-
- metrics = tracker.get_metrics(uid, window=100)
- reward = (0.75 * metrics['binary_mcc'] + 0.25 * metrics['multi_class_mcc'])
- reward *= compute_penalty_multiplier(pred_probs)
-
- miner_modality_rewards[modality] = reward
- miner_modality_metrics[modality] = metrics
-
- except Exception as e:
- bt.logging.error(f"Couldn't calculate reward for miner {uid}, prediction: {pred_probs}, label: {label}")
- bt.logging.exception(e)
- miner_rewards.append(0.0)
- continue
-
- total_reward = (
- 0.4 * miner_modality_rewards.get('video', 0.0) +
- 0.6 * miner_modality_rewards.get('image', 0.0)
- )
- miner_rewards.append(total_reward)
- miner_metrics.append(miner_modality_metrics)
-
- return np.array(miner_rewards), miner_metrics
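Putting the removed reward formula together: per modality, reward = (0.75 · binary MCC + 0.25 · multi-class MCC) · penalty, and the final reward blends 0.4 · video + 0.6 · image. A small worked example with hypothetical metric values:

```python
import numpy as np

def compute_penalty_multiplier(y_pred):
    sum_check = np.abs(np.sum(y_pred) - 1.0) < 1e-6
    range_check = np.all((y_pred >= 0.0) & (y_pred <= 1.0))
    return 1.0 if (sum_check and range_check) else 0.0

# Hypothetical per-modality metrics for one miner
image_metrics = {"binary_mcc": 0.9, "multi_class_mcc": 0.7}
video_metrics = {"binary_mcc": 0.5, "multi_class_mcc": 0.4}
pred_probs = np.array([0.1, 0.8, 0.1])  # valid: sums to 1, all in [0, 1]

image_reward = (0.75 * image_metrics["binary_mcc"]
                + 0.25 * image_metrics["multi_class_mcc"]) * compute_penalty_multiplier(pred_probs)
video_reward = (0.75 * video_metrics["binary_mcc"]
                + 0.25 * video_metrics["multi_class_mcc"]) * compute_penalty_multiplier(pred_probs)

total = 0.4 * video_reward + 0.6 * image_reward  # 0.4 * 0.475 + 0.6 * 0.85 = 0.70
print(total)
```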
diff --git a/bitmind/validator/scripts/__init__.py b/bitmind/validator/scripts/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/bitmind/validator/scripts/prune_wandb_cache.py b/bitmind/validator/scripts/prune_wandb_cache.py
deleted file mode 100644
index aa4296cb..00000000
--- a/bitmind/validator/scripts/prune_wandb_cache.py
+++ /dev/null
@@ -1,67 +0,0 @@
-#!/usr/bin/env python3
-"""W&B Cache Cleaning Script"""
-import os
-import sys
-import glob
-import shutil
-import time
-import argparse
-from datetime import datetime
-
-def clean_wandb_cache(wandb_dir, hours=1):
- """Cleans wandb runs except recent ones and latest-run."""
- if not os.path.exists(wandb_dir):
- print(f"W&B directory not found: {wandb_dir}")
- return
-
- run_dirs = [d for d in glob.glob(os.path.join(wandb_dir, "run-*")) if os.path.isdir(d)]
-
- if not run_dirs:
- print("No W&B runs found.")
- return
-
- # Keep recent runs
- current_time = time.time()
- cutoff_time = current_time - (hours * 3600)
- recent_runs = [d for d in run_dirs if os.path.getmtime(d) > cutoff_time]
-
- # Preserve latest-run target
- latest_run_link = os.path.join(wandb_dir, "latest-run")
- if os.path.exists(latest_run_link) and os.path.isdir(latest_run_link):
- try:
- latest_run_target = os.path.realpath(latest_run_link)
- if latest_run_target not in recent_runs and latest_run_target in run_dirs:
- recent_runs.append(latest_run_target)
- print(f"Preserving latest-run: {os.path.basename(latest_run_target)}")
- except Exception as e:
- print(f"Error with latest-run: {e}")
-
- print(f"Keeping {len(recent_runs)} runs (modified in last {hours} hours):")
- for run in recent_runs:
- print(f" - {os.path.basename(run)}")
-
- # Remove old runs
- runs_removed = 0
- space_freed = 0
-
- for run_dir in run_dirs:
- if run_dir not in recent_runs:
- try:
- dir_size = sum(os.path.getsize(os.path.join(dirpath, filename))
- for dirpath, _, filenames in os.walk(run_dir)
- for filename in filenames)
- space_freed += dir_size
- shutil.rmtree(run_dir)
- runs_removed += 1
- except Exception as e:
- print(f"Error removing {run_dir}: {e}")
-
- print(f"Cleaned {runs_removed} runs, freed {space_freed / (1024*1024):.2f} MB")
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="Clean W&B cache directory")
- parser.add_argument("--dir", default="./wandb", help="W&B directory path (default: ./wandb)")
- parser.add_argument("--hours", type=int, default=1, help="Keep runs newer than this many hours (default: 1)")
- args = parser.parse_args()
-
- clean_wandb_cache(args.dir, args.hours)
diff --git a/bitmind/validator/scripts/run_cache_updater.py b/bitmind/validator/scripts/run_cache_updater.py
deleted file mode 100644
index 1747b3a6..00000000
--- a/bitmind/validator/scripts/run_cache_updater.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import os
-os.environ["CUDA_VISIBLE_DEVICES"] = ""
-
-import bittensor as bt
-import asyncio
-import argparse
-
-from bitmind.validator.cache.image_cache import ImageCache
-from bitmind.validator.cache.video_cache import VideoCache
-from bitmind.validator.scripts.util import load_validator_info, init_wandb_run
-from bitmind.validator.config import (
- IMAGE_DATASETS,
- VIDEO_DATASETS,
- IMAGE_CACHE_UPDATE_INTERVAL,
- VIDEO_CACHE_UPDATE_INTERVAL,
- IMAGE_PARQUET_CACHE_UPDATE_INTERVAL,
- VIDEO_ZIP_CACHE_UPDATE_INTERVAL,
- REAL_VIDEO_CACHE_DIR,
- REAL_IMAGE_CACHE_DIR,
- SYNTH_IMAGE_CACHE_DIR,
- SEMISYNTH_VIDEO_CACHE_DIR,
- SEMISYNTH_IMAGE_CACHE_DIR,
- MAX_COMPRESSED_GB,
- MAX_EXTRACTED_GB
-)
-
-
-async def main(args):
-
- if args.modality in ['all', 'image']:
- bt.logging.info("Starting real image cache updater")
- real_image_cache = ImageCache(
- cache_dir=REAL_IMAGE_CACHE_DIR,
- datasets=IMAGE_DATASETS['real'],
- parquet_update_interval=args.image_parquet_interval,
- image_update_interval=args.image_interval,
- num_parquets_per_dataset=5,
- num_images_per_source=100,
- max_extracted_size_gb=MAX_EXTRACTED_GB,
- max_compressed_size_gb=MAX_COMPRESSED_GB
- )
- real_image_cache.start_updater()
-
- bt.logging.info("Starting semisynthetic image cache updater")
- semisynth_image_cache = ImageCache(
- cache_dir=SEMISYNTH_IMAGE_CACHE_DIR,
- datasets=IMAGE_DATASETS['semisynthetic'],
- parquet_update_interval=args.image_parquet_interval,
- image_update_interval=args.image_interval,
- num_parquets_per_dataset=5,
- num_images_per_source=100,
- max_extracted_size_gb=MAX_EXTRACTED_GB,
- max_compressed_size_gb=MAX_COMPRESSED_GB
- )
- semisynth_image_cache.start_updater()
-
- bt.logging.info("Starting synthetic image cache updater")
- synth_image_cache = ImageCache(
- cache_dir=SYNTH_IMAGE_CACHE_DIR,
- datasets=IMAGE_DATASETS['synthetic'],
- parquet_update_interval=args.image_parquet_interval,
- image_update_interval=args.image_interval,
- num_parquets_per_dataset=5,
- num_images_per_source=100,
- max_extracted_size_gb=MAX_EXTRACTED_GB,
- max_compressed_size_gb=MAX_COMPRESSED_GB
- )
- synth_image_cache.start_updater()
-
- if args.modality in ['all', 'video']:
- bt.logging.info("Starting real video cache updater")
- real_video_cache = VideoCache(
- cache_dir=REAL_VIDEO_CACHE_DIR,
- datasets=VIDEO_DATASETS['real'],
- video_update_interval=args.video_interval,
- zip_update_interval=args.video_zip_interval,
- num_zips_per_dataset=2,
- num_videos_per_zip=100,
- max_extracted_size_gb=MAX_EXTRACTED_GB,
- max_compressed_size_gb=100,
- )
- real_video_cache.start_updater()
-
- bt.logging.info("Starting semisynthetic video cache updater")
- semisynth_video_cache = VideoCache(
- cache_dir=SEMISYNTH_VIDEO_CACHE_DIR,
- datasets=VIDEO_DATASETS['semisynthetic'],
- video_update_interval=args.video_interval,
- zip_update_interval=args.video_zip_interval,
- num_zips_per_dataset=2,
- num_videos_per_zip=100,
- max_extracted_size_gb=MAX_EXTRACTED_GB,
- max_compressed_size_gb=MAX_COMPRESSED_GB
- )
- semisynth_video_cache.start_updater()
-
- while True:
- bt.logging.info(f"Running cache updaters for: {args.modality}")
- await asyncio.sleep(600) # Status update every 10 minutes
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- '--modality', type=str, default='all', choices=['all', 'video', 'image'],
- help='Which cache updater(s) to run')
- parser.add_argument(
- '--image-interval', type=int, default=IMAGE_CACHE_UPDATE_INTERVAL,
- help='Update interval for images in hours')
- parser.add_argument(
- '--image-parquet-interval', type=int, default=IMAGE_PARQUET_CACHE_UPDATE_INTERVAL,
- help='Update interval for image parquet files in hours')
- parser.add_argument(
- '--video-interval', type=int, default=VIDEO_CACHE_UPDATE_INTERVAL,
- help='Update interval for videos in hours')
- parser.add_argument(
- '--video-zip-interval', type=int, default=VIDEO_ZIP_CACHE_UPDATE_INTERVAL,
- help='Update interval for video zip files in hours')
- args = parser.parse_args()
-
- bt.logging.set_info()
- init_wandb_run(run_base_name='cache-updater', **load_validator_info())
-
- try:
- asyncio.run(main(args))
- except KeyboardInterrupt:
- bt.logging.info("Shutting down cache updaters...")
diff --git a/bitmind/validator/scripts/run_data_generator.py b/bitmind/validator/scripts/run_data_generator.py
deleted file mode 100644
index 660200e5..00000000
--- a/bitmind/validator/scripts/run_data_generator.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import argparse
-import time
-
-import bittensor as bt
-
-from bitmind.validator.scripts.util import load_validator_info, init_wandb_run
-from bitmind.synthetic_data_generation import SyntheticDataGenerator
-from bitmind.validator.cache import ImageCache
-from bitmind.validator.config import (
- REAL_IMAGE_CACHE_DIR,
- SN34_CACHE_DIR,
- MODEL_NAMES,
- get_task
-)
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('--image-cache-dir', type=str, default=REAL_IMAGE_CACHE_DIR,
- help='Directory containing real images to use as reference')
- parser.add_argument('--output-dir', type=str, default=SN34_CACHE_DIR,
- help='Directory to save generated data')
- parser.add_argument('--device', type=str, default='cuda',
- help='Device to run generation on (cuda/cpu)')
- parser.add_argument('--batch-size', type=int, default=3,
- help='Number of images to generate per batch')
- parser.add_argument('--model', type=str, default=None, choices=MODEL_NAMES,
- help='Specific model to test. If not specified, uses random models')
- args = parser.parse_args()
-
- if args.model:
- bt.logging.info(f"Using model {args.model} ({get_task(args.model)})")
- else:
- bt.logging.info(f"No model selected.")
-
- bt.logging.set_info()
- init_wandb_run(run_base_name='data-generator', **load_validator_info())
-
- image_cache = ImageCache(args.image_cache_dir)
- while True:
- if image_cache._extracted_cache_empty():
- bt.logging.info("SyntheticDataGenerator waiting for real image cache to populate")
- time.sleep(5)
- continue
- bt.logging.info("Image cache was populated! Proceeding to data generation")
- break
-
- sdg = SyntheticDataGenerator(
- prompt_type='annotation',
- use_random_model=args.model is None,
- model_name=args.model,
- device=args.device,
- image_cache=image_cache,
- output_dir=args.output_dir)
-
- bt.logging.info("Starting data generator service")
- sdg.batch_generate(batch_size=1)
-
- while True:
- try:
- sdg.batch_generate(batch_size=args.batch_size)
- except Exception as e:
- bt.logging.error(f"Error in batch generation: {str(e)}")
- time.sleep(5)
diff --git a/bitmind/validator/scripts/util.py b/bitmind/validator/scripts/util.py
deleted file mode 100644
index 9beb727c..00000000
--- a/bitmind/validator/scripts/util.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import time
-import yaml
-
-import wandb
-import bittensor as bt
-
-import bitmind
-from bitmind.validator.config import (
- WANDB_ENTITY,
- TESTNET_WANDB_PROJECT,
- MAINNET_WANDB_PROJECT,
- MAINNET_UID,
- VALIDATOR_INFO_PATH
-)
-
-def load_validator_info(max_wait: int = 300):
- start_time = time.time()
- while True:
- try:
- with open(VALIDATOR_INFO_PATH, 'r') as f:
- validator_info = yaml.safe_load(f)
- bt.logging.info(f"Loaded validator info from {VALIDATOR_INFO_PATH}")
- return validator_info
- except FileNotFoundError:
- if time.time() - start_time > max_wait:
- bt.logging.error(f"Validator info not found at {VALIDATOR_INFO_PATH} after waiting 5 minutes. Exiting.")
- exit(1)
- bt.logging.info(f"Waiting for validator info at {VALIDATOR_INFO_PATH}")
- time.sleep(3)
- continue
- except yaml.YAMLError:
- bt.logging.error(f"Could not parse validator info at {VALIDATOR_INFO_PATH}")
- validator_info = {
- 'uid': 'ParseError',
- 'hotkey': 'ParseError',
- 'full_path': 'ParseError',
- 'netuid': TESTNET_WANDB_PROJECT
- }
- return validator_info
-
-
-def init_wandb_run(run_base_name: str, uid: str, hotkey: str, netuid: int, full_path: str) -> None:
- """
- Initialize a Weights & Biases run for tracking the validator.
-
- Args:
- vali_uid: The validator's uid
- vali_hotkey: The validator's hotkey address
- netuid: The network ID (mainnet or testnet)
- vali_full_path: Validator's bittensor directory
-
- Returns:
- None
- """
- run_name = f'{run_base_name}-{uid}-{bitmind.__version__}'
-
- config = {
- 'run_name': run_name,
- 'uid': uid,
- 'hotkey': hotkey,
- 'version': bitmind.__version__
- }
-
- wandb_project = TESTNET_WANDB_PROJECT
- if netuid == MAINNET_UID:
- wandb_project = MAINNET_WANDB_PROJECT
-
- # Initialize the wandb run for the single project
- bt.logging.info(f"Initializing W&B run for '{WANDB_ENTITY}/{wandb_project}'")
- try:
- return wandb.init(
- name=run_name,
- project=wandb_project,
- entity=WANDB_ENTITY,
- config=config,
- dir=full_path,
- reinit=True
- )
- except wandb.UsageError as e:
- bt.logging.warning(e)
- bt.logging.warning("Did you run wandb login?")
- return
\ No newline at end of file
diff --git a/bitmind/validator/verify_models.py b/bitmind/validator/verify_models.py
deleted file mode 100644
index ff80c9e0..00000000
--- a/bitmind/validator/verify_models.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import os
-from bitmind.synthetic_data_generation import SyntheticDataGenerator
-from bitmind.validator.config import MODEL_NAMES, IMAGE_ANNOTATION_MODEL, TEXT_MODERATION_MODEL
-import bittensor as bt
-
-
-def is_model_cached(model_name):
- """
- Check if the specified model is cached by looking for its directory in the Hugging Face cache.
-
- Args:
- model_name (str): The name of the model to check.
-
- Returns:
- bool: True if the model is cached, False otherwise.
- """
- cache_dir = os.path.expanduser('~/.cache/huggingface/')
- # Format the directory name correctly by replacing each slash with double dashes
- model_dir = f"models--{model_name.replace('/', '--')}"
-
- # Construct the full path to where the model directory should be
- model_path = os.path.join(cache_dir, model_dir)
-
- # Check if the model directory exists
- if os.path.isdir(model_path):
- print(f"{model_name} is in HF cache. Skipping....")
- return True
- else:
- print(f"{model_name} is not cached. Downloading....")
- return False
-
-
-def main():
- """
- Main function to verify and download validator models.
-
- This function checks if the required models are cached and downloads them if necessary.
- It also initializes and loads diffusers for uncached models.
- """
- bt.logging.info("Verifying validator model downloads....")
- synthetic_image_generator = SyntheticDataGenerator(
- prompt_type='annotation',
- image_cache='test',
- use_random_model=True
- )
-
- # Check and load annotation and moderation models if not cached
- if not is_model_cached(IMAGE_ANNOTATION_MODEL) or not is_model_cached(TEXT_MODERATION_MODEL):
- synthetic_image_generator.prompt_generator.load_models()
- synthetic_image_generator.prompt_generator.clear_gpu()
-
- # Initialize and load diffusers if not cached
- for model_name in MODEL_NAMES:
- if not is_model_cached(model_name):
- synthetic_image_generator = SyntheticDataGenerator(
- prompt_type=None,
- use_random_model=False,
- model_name=model_name
- )
- synthetic_image_generator.load_model(model_name)
- synthetic_image_generator.clear_gpu()
-
-
-if __name__ == "__main__":
- main()
diff --git a/bitmind/validator/video_utils.py b/bitmind/validator/video_utils.py
deleted file mode 100644
index 3cb4137d..00000000
--- a/bitmind/validator/video_utils.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import tempfile
-from pathlib import Path
-from typing import Optional, BinaryIO, List, Union
-
-import bittensor as bt
-import ffmpeg
-import numpy as np
-from moviepy import VideoFileClip
-from PIL import Image
-
-from .cache.util import seconds_to_str
-
-
-def video_to_pil(video_path: Union[str, Path]) -> List[Image.Image]:
- """Load video file and convert it to a list of PIL images.
-
- Args:
- video_path: Path to the input video file.
-
- Returns:
- List of PIL Image objects representing each frame of the video.
- """
- clip = VideoFileClip(str(video_path))
- frames = [Image.fromarray(np.array(frame)) for frame in clip.iter_frames()]
- clip.close()
- return frames
-
-
-def clip_video(
- video_path: str,
- start: int,
- num_seconds: int
-) -> Optional[BinaryIO]:
- """Extract a clip from a video file.
-
- Args:
- video_path: Path to the input video file.
- start: Start time in seconds.
- num_seconds: Duration of the clip in seconds.
-
- Returns:
- A temporary file object containing the clipped video,
- or None if the operation fails.
-
- Raises:
- ffmpeg.Error: If FFmpeg encounters an error during processing.
- """
- temp_fileobj = tempfile.NamedTemporaryFile(suffix=".mp4")
- try:
- (
- ffmpeg
- .input(video_path, ss=seconds_to_str(start), t=str(num_seconds))
- .output(temp_fileobj.name, vf='fps=1')
- .overwrite_output()
- .run(capture_stderr=True)
- )
- return temp_fileobj
- except ffmpeg.Error as e:
- bt.logging.error(f"FFmpeg error: {e.stderr.decode()}")
- raise
-
-
-def get_video_duration(filename: str) -> int:
- """Get the duration of a video file in seconds.
-
- Args:
- filename: Path to the video file.
-
- Returns:
- Duration of the video in seconds.
-
- Raises:
- KeyError: If video stream information cannot be found.
- """
- metadata = ffmpeg.probe(filename)
- video_stream = next(
- (stream for stream in metadata['streams']
- if stream['codec_type'] == 'video'),
- None
- )
- if not video_stream:
- raise KeyError("No video stream found in the file")
- return int(float(video_stream['duration']))
-
-
-def copy_audio(video_path: str) -> BinaryIO:
- """Extract the audio stream from a video file.
-
- Args:
- video_path: Path to the input video file.
-
- Returns:
- A temporary file object containing the extracted audio stream.
-
- Raises:
- ffmpeg.Error: If FFmpeg encounters an error during processing.
- """
- temp_audiofile = tempfile.NamedTemporaryFile(suffix=".aac")
- (
- ffmpeg
- .input(video_path)
- .output(temp_audiofile.name, vn=None, acodec='copy')
- .overwrite_output()
- .run(quiet=True)
- )
- return temp_audiofile
diff --git a/bitmind/wandb_utils.py b/bitmind/wandb_utils.py
new file mode 100644
index 00000000..ad56bf5c
--- /dev/null
+++ b/bitmind/wandb_utils.py
@@ -0,0 +1,324 @@
+import glob
+import json
+import os
+import shutil
+import time
+import uuid
+
+import bittensor as bt
+import wandb
+
+
+class WandbLogger:
+ def __init__(self, config, validator_uid, validator_hotkey):
+ """
+ Initialize the WandB logger using a single project for both media and results.
+
+ Args:
+ config: Bittensor config object
+ validator_uid: Validator UID
+ validator_hotkey: Validator hotkey for signing
+ """
+ self.config = config
+ self.wandb_dir = config.neuron.full_path
+
+ self.uid = validator_uid
+ self.hotkey = validator_hotkey
+ self.run = None
+
+ self.session_artifacts = set()
+
+ clean_wandb_cache(self.wandb_dir)
+
+ def start_new_run(self):
+ """
+ Start a new validator run, finishing the current run first if one is active.
+
+ Returns:
+ wandb.Run: The newly started wandb run
+ """
+ clean_wandb_cache(self.wandb_dir)
+ if self.run is None or not wandb.run:
+ self.run = init_wandb(
+ self.config, "validator", self.uid, self.hotkey, self.wandb_dir
+ )
+ else:
+ self.run.finish()
+ self.run = init_wandb(
+ self.config, "validator", self.uid, self.hotkey, self.wandb_dir
+ )
+ return self.run
+
+ def _ensure_run(self):
+ """
+ Ensure validator run is active and return it.
+
+ Returns:
+ wandb.Run: The active wandb run
+ """
+ if self.run is None or not wandb.run:
+ self.run = init_wandb(
+ self.config, "validator", self.uid, self.hotkey, self.wandb_dir
+ )
+ return self.run
+
+ def _check_media_exists(self, filepath):
+ """
+ Check if a media file has already been logged to WandB using only UUID lookup.
+
+ Args:
+ filepath: Path to the media file to check
+
+ Returns:
+ tuple: (exists (bool), media_uuid (str or None))
+ """
+ metadata_path = os.path.splitext(filepath)[0] + ".json"
+ if os.path.exists(metadata_path):
+ try:
+ with open(metadata_path, "r") as f:
+ metadata = json.load(f)
+
+ media_uuid = metadata.get("media_uuid")
+ if media_uuid:
+ api = wandb.Api()
+ project = f"subnet-{self.config.netuid}-validator"
+ artifact_path = f"{self.config.wandb.entity}/{project}/media-{media_uuid}:latest"
+ try:
+ artifact = api.artifact(artifact_path)
+ return True, media_uuid
+ except wandb.errors.CommError:
+ pass
+ except (json.JSONDecodeError, IOError) as e:
+ bt.logging.warning(f"Error reading metadata file: {e}")
+
+ return False, None
+
+ def _maybe_log_media(self, media_path, metadata_path):
+ """
+ Log media as a WandB Artifact, with simple UUID-based deduplication.
+ Only logs media that hasn't been logged yet.
+ Only logs synthetic, locally generated media.
+
+ Args:
+ media_path: Path to the media file
+ metadata_path: Path to the metadata JSON file
+
+ Returns:
+ str or None: UUID assigned to the media if logged, None if not logged
+ """
+ exists, existing_uuid = self._check_media_exists(media_path)
+ if exists:
+ bt.logging.info(f"Media already exists in WandB with UUID: {existing_uuid}")
+ return existing_uuid
+
+ run = self._ensure_run()
+
+ metadata = {}
+ if os.path.exists(metadata_path):
+ try:
+ with open(metadata_path, "r") as f:
+ metadata = json.load(f)
+ except json.JSONDecodeError:
+ bt.logging.warning(f"Error parsing metadata file: {metadata_path}")
+
+ # Only create uuids for and log locally generated synthetic media.
+ # All other media are already stored on Huggingface
+ if not metadata.get("model_name"):
+ return None
+
+ if not metadata.get("media_uuid"):
+ metadata["media_uuid"] = str(uuid.uuid4())
+ try:
+ with open(metadata_path, "w") as f:
+ json.dump(metadata, f, indent=2)
+ except IOError as e:
+ bt.logging.warning(f"Error writing metadata file: {e}")
+
+ media_uuid = metadata["media_uuid"]
+ media_artifact = wandb.Artifact(
+ name=f"media-{media_uuid}", type="media", metadata=metadata
+ )
+
+ extension = os.path.splitext(media_path)[1]
+ media_artifact.add_file(media_path, f"media{extension}")
+
+ run.log_artifact(media_artifact)
+
+ if "media_uuids" not in list(run.summary.keys()):
+ run.summary["media_uuids"] = []
+
+ media_uuids = run.summary.get("media_uuids", [])
+ if media_uuid not in media_uuids:
+ run.summary["media_uuids"] = media_uuids + [media_uuid]
+ # run.summary.update()
+
+ bt.logging.info(f"Logged media file to WandB with UUID: {media_uuid}")
+ return media_uuid
+
+ def _log_challenge_results(self, challenge_results, media_uuids):
+ """
+ Log challenge results with reference to media artifact.
+
+ Args:
+ challenge_results: Dictionary of challenge results
+ media_uuids: List of UUIDs of the associated media
+ """
+ run = self._ensure_run()
+ log_data = {
+ "results": challenge_results,
+ "media_uuids": media_uuids,
+ }
+
+ run.log(log_data)
+ bt.logging.info(f"Logged challenge results with media UUIDs: {media_uuids}")
+
+ def log(self, media_sample, challenge_results):
+ """
+ Combined method to log both media and challenge results.
+
+ Args:
+ media_sample: Dictionary containing media paths and metadata paths
+ challenge_results: Dictionary of challenge results
+
+ Returns:
+ list: List of UUIDs assigned to the logged media
+ """
+
+ # Step 1: Log media if applicable
+ # Only locally generated synthetic media are logged
+ media_path = media_sample.get("path")
+ metadata_path = media_sample.get("metadata_path")
+ if media_path and metadata_path:
+ media_uuids = [self._maybe_log_media(media_path, metadata_path)]
+ else:
+ media_uuids = []
+ for i in range(1):
+ media_path = media_sample.get(f"sample_{i}", {}).get("path")
+ metadata_path = media_sample.get(f"sample_{i}", {}).get("metadata_path")
+ if media_path and metadata_path:
+ media_uuids.append(self._maybe_log_media(media_path, metadata_path))
+
+ # Step 2: Log challenge results with reference to logged media uuid if available
+ self._log_challenge_results(challenge_results, media_uuids)
+ return media_uuids
+
+ def finish(self):
+ """Finish the current run if it exists."""
+ if self.run and wandb.run:
+ self.run.finish()
+ self.run = None
+
+
+def init_wandb(
+ config: bt.config, process: str, uid: int, hotkey: bt.Keypair, wandb_dir: str = None
+) -> wandb.run:
+ """
+ Initialize a Weights & Biases run.
+
+ Args:
+ config: Bittensor config object
+ process: Valid options are 'validator', 'data-generator', 'media-store'
+ uid: Validator uid
+ hotkey: Bittensor keypair for signing the run
+ wandb_dir: Optional directory for wandb files
+
+ Returns:
+ wandb.run: The initialized wandb run, or None if initialization fails
+ """
+ from bitmind import __version__
+
+ project = f"subnet-{config.netuid}-{process}"
+ run_name = f"{process}-{uid}-{__version__}"
+ config.run_name = run_name
+ config.uid = uid
+ config.hotkey = hotkey.ss58_address
+ config.version = __version__
+
+ bt.logging.info(f"Initializing wandb run in '{config.wandb.entity}/{project}'")
+
+ try:
+ run = wandb.init(
+ name=run_name,
+ project=project,
+ entity=config.wandb.entity,
+ config=config,
+ dir=wandb_dir if wandb_dir else config.neuron.full_path,
+ reinit=True,
+ )
+ except wandb.UsageError as e:
+ bt.logging.warning(e)
+ bt.logging.warning("Did you run wandb login?")
+ return
+
+ # sign the run to prove it's from this hotkey
+ signature = hotkey.sign(run.id.encode()).hex()
+ config.signature = signature
+ wandb.config.update(config, allow_val_change=True)
+
+ bt.logging.success(f"Started wandb run {run_name}")
+ return run
+
+
+def clean_wandb_cache(wandb_dir, hours=1):
+ """
+ Cleans wandb runs except recent ones and latest-run.
+
+ Args:
+ wandb_dir: Directory containing wandb run files
+ hours: Number of hours to keep runs for (default: 1)
+ """
+ if not os.path.exists(wandb_dir):
+ bt.logging.warning(f"W&B directory not found: {wandb_dir}")
+ return
+
+ run_dirs = [
+ d for d in glob.glob(os.path.join(wandb_dir, "run-*")) if os.path.isdir(d)
+ ]
+
+ if not run_dirs:
+ bt.logging.info("No W&B runs found.")
+ return
+
+ # Keep recent runs
+ current_time = time.time()
+ cutoff_time = current_time - (hours * 3600)
+ recent_runs = [d for d in run_dirs if os.path.getmtime(d) > cutoff_time]
+
+ # Preserve latest-run target
+ latest_run_link = os.path.join(wandb_dir, "latest-run")
+ if os.path.exists(latest_run_link) and os.path.isdir(latest_run_link):
+ try:
+ latest_run_target = os.path.realpath(latest_run_link)
+ if latest_run_target not in recent_runs and latest_run_target in run_dirs:
+ recent_runs.append(latest_run_target)
+ bt.logging.debug(
+ f"Preserving latest-run: {os.path.basename(latest_run_target)}"
+ )
+ except Exception as e:
+ bt.logging.warning(f"Error with latest-run: {e}")
+
+ bt.logging.info(
+ f"Keeping {len(recent_runs)} runs (modified in last {hours} hours):"
+ )
+ for run in recent_runs:
+ bt.logging.info(f" - {os.path.basename(run)}")
+
+ runs_removed = 0
+ space_freed = 0
+ for run_dir in run_dirs:
+ if run_dir not in recent_runs:
+ try:
+ dir_size = sum(
+ os.path.getsize(os.path.join(dirpath, filename))
+ for dirpath, _, filenames in os.walk(run_dir)
+ for filename in filenames
+ )
+ space_freed += dir_size
+ shutil.rmtree(run_dir)
+ runs_removed += 1
+ except Exception as e:
+ bt.logging.warning(f"Error removing {run_dir}: {e}")
+
+ bt.logging.info(
+ f"Cleaned {runs_removed} runs, freed {space_freed / (1024*1024):.2f} MB"
+ )
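A hedged usage sketch for the new WandbLogger: the media_sample keys and challenge_results fields follow the log() docstring, but the concrete paths, uid, and result fields below are illustrative placeholders, and the config is assumed to be a fully parsed validator config with netuid, wandb.entity, and neuron.full_path set.

```python
import bittensor as bt
from bitmind.wandb_utils import WandbLogger

# In practice these come from the validator process; shown here only as placeholders.
config = bt.config()  # assumed: the validator's parsed config
wallet = bt.wallet(name="validator", hotkey="default")

logger = WandbLogger(config=config, validator_uid=123, validator_hotkey=wallet.hotkey)
logger.start_new_run()

media_sample = {
    "path": "/path/to/generated/video.mp4",        # hypothetical paths
    "metadata_path": "/path/to/generated/video.json",
}
challenge_results = {"label": 1, "rewards": [0.7, 0.0], "miner_uids": [5, 9]}

media_uuids = logger.log(media_sample, challenge_results)
logger.finish()
```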
diff --git a/contrib/CODE_REVIEW_DOCS.md b/contrib/CODE_REVIEW_DOCS.md
deleted file mode 100644
index 9909606a..00000000
--- a/contrib/CODE_REVIEW_DOCS.md
+++ /dev/null
@@ -1,72 +0,0 @@
-# Code Review
-### Conceptual Review
-
-A review can be a conceptual review, where the reviewer leaves a comment
- * `Concept (N)ACK`, meaning "I do (not) agree with the general goal of this pull
- request",
- * `Approach (N)ACK`, meaning `Concept ACK`, but "I do (not) agree with the
- approach of this change".
-
-A `NACK` needs to include a rationale why the change is not worthwhile.
-NACKs without accompanying reasoning may be disregarded.
-After conceptual agreement on the change, code review can be provided. A review
-begins with `ACK BRANCH_COMMIT`, where `BRANCH_COMMIT` is the top of the PR
-branch, followed by a description of how the reviewer did the review. The
-following language is used within pull request comments:
-
- - "I have tested the code", involving change-specific manual testing in
- addition to running the unit, functional, or fuzz tests, and in case it is
- not obvious how the manual testing was done, it should be described;
- - "I have not tested the code, but I have reviewed it and it looks
- OK, I agree it can be merged";
- - A "nit" refers to a trivial, often non-blocking issue.
-
-### Code Review
-Project maintainers reserve the right to weigh the opinions of peer reviewers
-using common sense judgement and may also weigh based on merit. Reviewers that
-have demonstrated a deeper commitment and understanding of the project over time
-or who have clear domain expertise may naturally have more weight, as one would
-expect in all walks of life.
-
-Where a patch set affects consensus-critical code, the bar will be much
-higher in terms of discussion and peer review requirements, keeping in mind that
-mistakes could be very costly to the wider community. This includes refactoring
-of consensus-critical code.
-
-Where a patch set proposes to change the Bittensor consensus, it must have been
-discussed extensively on the discord server and other channels, be accompanied by a widely
-discussed BIP and have a generally widely perceived technical consensus of being
-a worthwhile change based on the judgement of the maintainers.
-
-### Finding Reviewers
-
-As most reviewers are themselves developers with their own projects, the review
-process can be quite lengthy, and some amount of patience is required. If you find
-that you've been waiting for a pull request to be given attention for several
-months, there may be a number of reasons for this, some of which you can do something
-about:
-
- - It may be because of a feature freeze due to an upcoming release. During this time,
- only bug fixes are taken into consideration. If your pull request is a new feature,
- it will not be prioritized until after the release. Wait for the release.
- - It may be because the changes you are suggesting do not appeal to people. Rather than
- nits and critique, which require effort and means they care enough to spend time on your
- contribution, thundering silence is a good sign of widespread (mild) dislike of a given change
- (because people don't assume *others* won't actually like the proposal). Don't take
- that personally, though! Instead, take another critical look at what you are suggesting
- and see if it: changes too much, is too broad, doesn't adhere to the
- [developer notes](DEVELOPMENT_WORKFLOW.md), is dangerous or insecure, is messily written, etc.
- Identify and address any of the issues you find. Then ask e.g. on IRC if someone could give
- their opinion on the concept itself.
- - It may be because your code is too complex for all but a few people, and those people
- may not have realized your pull request even exists. A great way to find people who
- are qualified and care about the code you are touching is the
- [Git Blame feature](https://docs.github.com/en/github/managing-files-in-a-repository/managing-files-on-github/tracking-changes-in-a-file). Simply
- look up who last modified the code you are changing and see if you can find
- them and give them a nudge. Don't be incessant about the nudging, though.
- - Finally, if all else fails, ask on IRC or elsewhere for someone to give your pull request
- a look. If you think you've been waiting for an unreasonably long time (say,
- more than a month) for no particular reason (a few lines changed, etc.),
- this is totally fine. Try to return the favor when someone else is asking
- for feedback on their code, and the universe balances out.
- - Remember that the best thing you can do while waiting is give review to others!
\ No newline at end of file
diff --git a/contrib/CONTRIBUTING.md b/contrib/CONTRIBUTING.md
deleted file mode 100644
index ba33ce3c..00000000
--- a/contrib/CONTRIBUTING.md
+++ /dev/null
@@ -1,213 +0,0 @@
-# Contributing to Bittensor Subnet Development
-
-The following is a set of guidelines for contributing to the Bittensor ecosystem. These are **HIGHLY RECOMMENDED** guidelines, but not hard-and-fast rules. Use your best judgment, and feel free to propose changes to this document in a pull request.
-
-## Table Of Contents
-1. [How Can I Contribute?](#how-can-i-contribute)
- 1. [Communication Channels](#communication-channels)
- 1. [Code Contribution General Guideline](#code-contribution-general-guidelines)
- 1. [Pull Request Philosophy](#pull-request-philosophy)
- 1. [Pull Request Process](#pull-request-process)
- 1. [Addressing Feedback](#addressing-feedback)
- 1. [Squashing Commits](#squashing-commits)
- 1. [Refactoring](#refactoring)
- 1. [Peer Review](#peer-review)
- 1. [Suggesting Features](#suggesting-enhancements-and-features)
-
-
-## How Can I Contribute?
-TODO(developer): Define your desired contribution procedure.
-
-## Communication Channels
-TODO(developer): Place your communication channels here
-
-> Please follow the Bittensor Subnet [style guide](./STYLE.md) regardless of your contribution type.
-
-Here is a high-level summary:
-- Code consistency is crucial; adhere to established programming language conventions.
-- Use `black` to format your Python code; it ensures readability and consistency.
-- Write concise Git commit messages; summarize changes in ~50 characters.
-- Follow these six commit rules:
- - Atomic Commits: Focus on one task or fix per commit.
- - Subject and Body Separation: Use a blank line to separate the subject from the body.
- - Subject Line Length: Keep it under 50 characters for readability.
- - Imperative Mood: Write subject line as if giving a command or instruction.
- - Body Text Width: Wrap text manually at 72 characters.
- - Body Content: Explain what changed and why, not how.
-- Make use of your commit messages to simplify project understanding and maintenance.
-
-> For clear examples of each of the commit rules, see the style guide's [rules](./STYLE.md#the-six-rules-of-a-great-commit) section.
-
-### Code Contribution General Guidelines
-
-> Review the Bittensor Subnet [style guide](./STYLE.md) and [development workflow](./DEVELOPMENT_WORKFLOW.md) before contributing.
-
-
-#### Pull Request Philosophy
-
-Patchsets and enhancements should always be focused. A pull request could add a feature, fix a bug, or refactor code, but it should not contain a mixture of these. Please also avoid 'super' pull requests which attempt to do too much, are overly large, or overly complex as this makes review difficult.
-
-Specifically, pull requests must adhere to the following criteria:
-- Contain fewer than 50 files. PRs with more than 50 files will be closed.
-- If a PR introduces a new feature, it *must* include corresponding tests.
-- Other PRs (bug fixes, refactoring, etc.) should ideally also have tests, as they provide proof of concept and prevent regression.
-- Categorize your PR properly by using GitHub labels. This aids in the review process by informing reviewers about the type of change at a glance.
-- Make sure your code includes adequate comments. These should explain why certain decisions were made and how your changes work.
-- If your changes are extensive, consider breaking your PR into smaller, related PRs. This makes your contributions easier to understand and review.
-- Be active in the discussion about your PR. Respond promptly to comments and questions to help reviewers understand your changes and speed up the acceptance process.
-
-Generally, all pull requests must:
-
- - Have a clear use case, fix a demonstrable bug or serve the greater good of the project (e.g. refactoring for modularisation).
- - Be well peer-reviewed.
- - Follow code style guidelines.
- - Not break the existing test suite.
- - Where bugs are fixed, where possible, there should be unit tests demonstrating the bug and also proving the fix.
- - Change relevant comments and documentation when behaviour of code changes.
-
-#### Pull Request Process
-
-Please follow these steps to have your contribution considered by the maintainers:
-
-*Before* creating the PR:
-1. Read the [development workflow](./DEVELOPMENT_WORKFLOW.md) defined for this repository to understand our workflow.
-2. Ensure your PR meets the criteria stated in the 'Pull Request Philosophy' section.
-3. Include relevant tests for any fixed bugs or new features as stated in the [testing guide](./TESTING.md).
-4. Ensure your commit messages are clear and concise. Include the issue number if applicable.
-5. If you have multiple commits, rebase them into a single commit using `git rebase -i`.
-6. Explain what your changes do and why you think they should be merged in the PR description consistent with the [style guide](./STYLE.md).
-
-*After* creating the PR:
-1. Verify that all [status checks](https://help.github.com/articles/about-status-checks/) are passing after you submit your pull request.
-2. Label your PR using GitHub's labeling feature. The labels help categorize the PR and streamline the review process.
-3. Document your code with comments that provide a clear understanding of your changes. Explain any non-obvious parts of your code or design decisions you've made.
-4. If your PR has extensive changes, consider splitting it into smaller, related PRs. This reduces the cognitive load on the reviewers and speeds up the review process.
-
-Please be responsive and participate in the discussion on your PR! This aids in clarifying any confusion or concerns and leads to quicker resolution and merging of your PR.
-
-> Note: If your changes are not ready for merge but you want feedback, create a draft pull request.
-
-Following these criteria will aid in quicker review and potential merging of your PR.
-While the prerequisites above must be satisfied prior to having your pull request reviewed, the reviewer(s) may ask you to complete additional design work, tests, or other changes before your pull request can be ultimately accepted.
-
-When you are ready to submit your changes, create a pull request:
-
-> **Always** follow the [style guide](./STYLE.md) and [development workflow](./DEVELOPMENT_WORKFLOW.md) before submitting pull requests.
-
-After you submit a pull request, it will be reviewed by the maintainers. They may ask you to make changes. Please respond to any comments and push your changes as a new commit.
-
-> Note: Be sure to merge the latest from "upstream" before making a pull request:
-
-```bash
-git remote add upstream https://github.com/opentensor/bittensor.git # TODO(developer): replace with your repo URL
-git fetch upstream
-git merge upstream/<your branch name>
-git push origin <your branch name>
-```
-
-#### Addressing Feedback
-
-After submitting your pull request, expect comments and reviews from other contributors. You can add more commits to your pull request by committing them locally and pushing to your fork.
-
-You are expected to reply to any review comments before your pull request is merged. You may update the code or reject the feedback if you do not agree with it, but you should express so in a reply. If there is outstanding feedback and you are not actively working on it, your pull request may be closed.
-
-#### Squashing Commits
-
-If your pull request contains fixup commits (commits that change the same line of code repeatedly) or too fine-grained commits, you may be asked to [squash](https://git-scm.com/docs/git-rebase#_interactive_mode) your commits before it will be reviewed. The basic squashing workflow is shown below.
-
- git checkout your_branch_name
- git rebase -i HEAD~n
- # n is normally the number of commits in the pull request.
- # Set commits (except the one in the first line) from 'pick' to 'squash', save and quit.
- # On the next screen, edit/refine commit messages.
- # Save and quit.
- git push -f # (force push to GitHub)
-
-Please update the resulting commit message, if needed. It should read as a coherent message. In most cases, this means not just listing the interim commits.
-
-If your change contains a merge commit, the above workflow may not work and you will need to remove the merge commit first. See the next section for details on how to rebase.
-
-Please refrain from creating several pull requests for the same change. Use the pull request that is already open (or was created earlier) to amend changes. This preserves the discussion and review that happened earlier for the respective change set.
-
-The length of time required for peer review is unpredictable and will vary from pull request to pull request.
-
-#### Refactoring
-
-Refactoring is a necessary part of any software project's evolution. The following guidelines cover refactoring pull requests for the project.
-
-There are three categories of refactoring: code-only moves, code style fixes, and code refactoring. In general, refactoring pull requests should not mix these three kinds of activities in order to make refactoring pull requests easy to review and uncontroversial. In all cases, refactoring PRs must not change the behaviour of code within the pull request (bugs must be preserved as is).
-
-Project maintainers aim for a quick turnaround on refactoring pull requests, so where possible keep them short, uncomplex and easy to verify.
-
-Pull requests that refactor the code should not be made by new contributors. It requires a certain level of experience to know where the code belongs and to understand the full ramifications (including the rebase effort for open pull requests). Trivial pull requests, or pull requests that refactor the code with no clear benefit, may be immediately closed by the maintainers to reduce unnecessary review workload.
-
-#### Peer Review
-
-Anyone may participate in peer review which is expressed by comments in the pull request. Typically reviewers will review the code for obvious errors, as well as test out the patch set and opine on the technical merits of the patch. Project maintainers take into account the peer review when determining if there is consensus to merge a pull request (remember that discussions may have taken place elsewhere, not just on GitHub). The following language is used within pull-request comments:
-
-- ACK means "I have tested the code and I agree it should be merged";
-- NACK means "I disagree this should be merged", and must be accompanied by sound technical justification. NACKs without accompanying reasoning may be disregarded;
-- utACK means "I have not tested the code, but I have reviewed it and it looks OK, I agree it can be merged";
-- Concept ACK means "I agree in the general principle of this pull request";
-- Nit refers to trivial, often non-blocking issues.
-
-Reviewers should include the commit(s) they have reviewed in their comments. This can be done by copying the commit SHA1 hash.
-
-A pull request that changes consensus-critical code is considerably more involved than a pull request that adds a feature to the wallet, for example. Such patches must be reviewed and thoroughly tested by several reviewers who are knowledgeable about the changed subsystems. Where new features are proposed, it is helpful for reviewers to try out the patch set on a test network and indicate that they have done so in their review. Project maintainers will take this into consideration when merging changes.
-
-For a more detailed description of the review process, see the [Code Review Guidelines](CODE_REVIEW_DOCS.md).
-
-> **Note:** If you find a **Closed** issue that seems like it is the same thing that you're experiencing, open a new issue and include a link to the original issue in the body of your new one.
-
-#### How Do I Submit A (Good) Bug Report?
-
-Please track bugs as GitHub issues.
-
-Explain the problem and include additional details to help maintainers reproduce the problem:
-
-* **Use a clear and descriptive title** for the issue to identify the problem.
-* **Describe the exact steps which reproduce the problem** in as many details as possible. For example, start by explaining how you started the application, e.g. which command exactly you used in the terminal, or how you started Bittensor otherwise. When listing steps, **don't just say what you did, but explain how you did it**. For example, if you ran with a set of custom configs, explain if you used a config file or command line arguments.
-* **Provide specific examples to demonstrate the steps**. Include links to files or GitHub projects, or copy/pasteable snippets, which you use in those examples. If you're providing snippets in the issue, use [Markdown code blocks](https://help.github.com/articles/markdown-basics/#multiple-lines).
-* **Describe the behavior you observed after following the steps** and point out what exactly is the problem with that behavior.
-* **Explain which behavior you expected to see instead and why.**
-* **Include screenshots and animated GIFs** which show you following the described steps and clearly demonstrate the problem. You can use [this tool](https://www.cockos.com/licecap/) to record GIFs on macOS and Windows, and [this tool](https://github.com/colinkeenan/silentcast) or [this tool](https://github.com/GNOME/byzanz) on Linux.
-* **If you're reporting that Bittensor crashed**, include a crash report with a stack trace from the operating system. On macOS, the crash report will be available in `Console.app` under "Diagnostic and usage information" > "User diagnostic reports". Include the crash report in the issue in a [code block](https://help.github.com/articles/markdown-basics/#multiple-lines), a [file attachment](https://help.github.com/articles/file-attachments-on-issues-and-pull-requests/), or put it in a [gist](https://gist.github.com/) and provide link to that gist.
-* **If the problem is related to performance or memory**, include a CPU profile capture with your report, if you're using a GPU then include a GPU profile capture as well. Look into the [PyTorch Profiler](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html) to look at memory usage of your model.
-* **If the problem wasn't triggered by a specific action**, describe what you were doing before the problem happened and share more information using the guidelines below.
-
-Provide more context by answering these questions:
-
-* **Did the problem start happening recently** (e.g. after updating to a new version) or was this always a problem?
-* If the problem started happening recently, **can you reproduce the problem in an older version of Bittensor?**
-* **Can you reliably reproduce the issue?** If not, provide details about how often the problem happens and under which conditions it normally happens.
-
-Include details about your configuration and environment:
-
-* **Which version of Bittensor Subnet are you using?**
-* **What commit hash are you on?** You can get the exact commit hash by checking `git log` and pasting the full commit hash.
-* **What's the name and version of the OS you're using**?
-* **Are you running Bittensor Subnet in a virtual machine?** If so, which VM software are you using and which operating systems and versions are used for the host and the guest?
-* **Are you running Bittensor Subnet in a dockerized container?** If so, have you made sure that your docker container contains your latest changes and is up to date with Master branch?
-
-### Suggesting Enhancements and Features
-
-This section guides you through submitting an enhancement suggestion, including completely new features and minor improvements to existing functionality. Following these guidelines helps maintainers and the community understand your suggestion :pencil: and find related suggestions :mag_right:.
-
-When you are creating an enhancement suggestion, please [include as many details as possible](#how-do-i-submit-a-good-enhancement-suggestion). Fill in [the template](https://bit.ly/atom-behavior-pr), including the steps that you imagine you would take if the feature you're requesting existed.
-
-#### Before Submitting An Enhancement Suggestion
-
-* **Check the [debugging guide](./DEBUGGING.md)** for tips — you might discover that the enhancement is already available. Most importantly, check if you're using the latest version of the project first.
-
-#### How Do I Submit A (Good) Enhancement Suggestion?
-
-* **Use a clear and descriptive title** for the issue to identify the problem.
-* **Provide a step-by-step description of the suggested enhancement** in as many details as possible.
-* **Provide specific examples to demonstrate the steps**. Include copy/pasteable snippets which you use in those examples, as [Markdown code blocks](https://help.github.com/articles/markdown-basics/#multiple-lines).
-* **Describe the current behavior** and **explain which behavior you expected to see instead** and why.
-* **Include screenshots and animated GIFs** which help you demonstrate the steps or point out the part of the project which the suggestion is related to. You can use [this tool](https://www.cockos.com/licecap/) to record GIFs on macOS and Windows, and [this tool](https://github.com/colinkeenan/silentcast) or [this tool](https://github.com/GNOME/byzanz) on Linux.
-* **Explain why this enhancement would be useful** to most users.
-* **List some other projects or applications where this enhancement exists.**
-* **Specify the name and version of the OS you're using.**
-
-Thank you for considering contributing to Bittensor! Any help is greatly appreciated along this journey to incentivize open and permissionless intelligence.
diff --git a/contrib/DEVELOPMENT_WORKFLOW.md b/contrib/DEVELOPMENT_WORKFLOW.md
deleted file mode 100644
index 13bb07b2..00000000
--- a/contrib/DEVELOPMENT_WORKFLOW.md
+++ /dev/null
@@ -1,165 +0,0 @@
-# Bittensor Subnet Development Workflow
-
-This is a highly recommended workflow to follow to keep your subnet project organized and to foster ease of contribution.
-
-## Table of contents
-
-- [Bittensor Subnet Development Workflow](#bittensor-subnet-development-workflow)
- - [Main Branches](#main-branches)
- - [Development Model](#development-model)
- - [Feature Branches](#feature-branches)
- - [Release Branches](#release-branches)
- - [Hotfix Branches](#hotfix-branches)
- - [Git Operations](#git-operations)
- - [Creating a Feature Branch](#creating-a-feature-branch)
- - [Merging Feature Branch into Staging](#merging-feature-branch-into-staging)
- - [Creating a Release Branch](#creating-a-release-branch)
- - [Finishing a Release Branch](#finishing-a-release-branch)
- - [Creating a Hotfix Branch](#creating-a-hotfix-branch)
- - [Finishing a Hotfix Branch](#finishing-a-hotfix-branch)
- - [Continuous Integration (CI) and Continuous Deployment (CD)](#continuous-integration-ci-and-continuous-deployment-cd)
- - [Versioning and Release Notes](#versioning-and-release-notes)
- - [Pending Tasks](#pending-tasks)
-
-## Main Branches
-
-Bittensor's codebase consists of two main branches: **main** and **staging**.
-
-**main**
-- This is Bittensor's live production branch, which should only be updated by the core development team. This branch is protected, so refrain from pushing or merging into it unless authorized.
-
-**staging**
-- This branch is continuously updated and is where you propose and merge changes. It's essentially Bittensor's active development branch.
-
-## Development Model
-
-### Feature Branches
-
-- Branch off from: `staging`
-- Merge back into: `staging`
-- Naming convention: `feature/<feature-name>/<creator's-name>`
-
-Feature branches are used to develop new features for upcoming or future releases. They exist as long as the feature is in development, but will eventually be merged into `staging` or discarded. Always delete your feature branch after merging to avoid unnecessary clutter.
-
-### Release Branches
-
-- Branch off from: `staging`
-- Merge back into: `staging` and then `main`
-- Naming convention: `release/<version>/<descriptive-message>/<creator's-name>`
-
-Release branches support the preparation of a new production release, allowing for minor bug fixes and preparation of metadata (version number, configuration, etc). All new features should be merged into `staging` and wait for the next big release.
-
-### Hotfix Branches
-
-General workflow:
-
-- Branch off from: `main` or `staging`
-- Merge back into: `staging` then `main`
-- Naming convention: `hotfix/<version>/<descriptive-message>/<creator's-name>`
-
-Hotfix branches are meant for quick fixes in the production environment. When a critical bug in a production version must be resolved immediately, a hotfix branch is created.
-
-## Git Operations
-
-#### Create a feature branch
-
-1. Branch from the **staging** branch.
- 1. Command: `git checkout -b feature/my-feature staging`
-
-> Rebase frequently with the updated staging branch so you do not face big conflicts before submitting your pull request. Remember, syncing your changes with other developers could also help you avoid big conflicts.
-
-#### Merge feature branch into staging
-
-In other words, integrate your changes into a branch that will be tested and prepared for release.
-
-1. Switch branch to staging: `git checkout staging`
-2. Merging feature branch into staging: `git merge --no-ff feature/my-feature`
-3. Pushing changes to staging: `git push origin staging`
-4. Delete feature branch: `git branch -d feature/my-feature` (alternatively, this can be navigated on the GitHub web UI)
-
-This operation is done by GitHub when merging a PR.
-
-So, what you have to keep in mind is:
-- Open the PR against the `staging` branch.
-- After merging a PR you should delete your feature branch. This will be strictly enforced.
-
-#### Creating a release branch
-
-1. Create branch from staging: `git checkout -b release/3.4.0/descriptive-message/creator's_name staging`
-2. Updating version with major or minor: `./scripts/update_version.sh major|minor`
-3. Commit file changes with new version: `git commit -a -m "Updated version to 3.4.0"`
-
-
-#### Finishing a Release Branch
-
-This involves releasing stable code and generating a new version for bittensor.
-
-1. Switch branch to main: `git checkout main`
-2. Merge release branch into main: `git merge --no-ff release/3.4.0/optional-descriptive-message`
-3. Tag changeset: `git tag -a v3.4.0 -m "Releasing v3.4.0: some comment about it"`
-4. Push changes to main: `git push origin main`
-5. Push tags to origin: `git push origin --tags`
-
-To keep the changes made in the __release__ branch, we need to merge those back into `staging`:
-
-- Switch branch to staging: `git checkout staging`.
-- Merging release branch into staging: `git merge --no-ff release/3.4.0/optional-descriptive-message`
-
-This step may well lead to a merge conflict (probably even, since we have changed the version number). If so, fix it and commit.
-
-
-#### Creating a hotfix branch
-1. Create branch from main: `git checkout -b hotfix/3.3.4/descriptive-message/creator's-name main`
-2. Update patch version: `./scripts/update_version.sh patch`
-3. Commit file changes with new version: `git commit -a -m "Updated version to 3.3.4"`
-4. Fix the bug and commit the fix: `git commit -m "Fixed critical production issue X"`
-
-#### Finishing a Hotfix Branch
-
-Finishing a hotfix branch involves merging the bugfix into both `main` and `staging`.
-
-1. Switch branch to main: `git checkout main`
-2. Merge hotfix into main: `git merge --no-ff hotfix/3.3.4/optional-descriptive-message`
-3. Tag new version: `git tag -a v3.3.4 -m "Releasing v3.3.4: descriptive comment about the hotfix"`
-4. Push changes to main: `git push origin main`
-5. Push tags to origin: `git push origin --tags`
-6. Switch branch to staging: `git checkout staging`
-7. Merge hotfix into staging: `git merge --no-ff hotfix/3.3.4/descriptive-message/creator's-name`
-8. Push changes to origin/staging: `git push origin staging`
-9. Delete hotfix branch: `git branch -d hotfix/3.3.4/optional-descriptive-message`
-
-The one exception to the rule here is that, **when a release branch currently exists, the hotfix changes need to be merged into that release branch, instead of** `staging`. Back-merging the bugfix into the __release__ branch will eventually result in the bugfix being merged into `staging` too, when the release branch is finished. (If work in `staging` immediately requires this bugfix and cannot wait for the release branch to be finished, you may safely merge the bugfix into `staging` now as well.)
-
-Finally, we remove the temporary branch:
-
-- `git branch -d hotfix/3.3.4/optional-descriptive-message`
-## Continuous Integration (CI) and Continuous Deployment (CD)
-
-Continuous Integration (CI) is a software development practice where members of a team integrate their work frequently. Each integration is verified by an automated build and test process to detect integration errors as quickly as possible.
-
-Continuous Deployment (CD) is a software engineering approach in which software functionalities are delivered frequently through automated deployments.
-
-- **CircleCI job**: Create jobs in CircleCI to automate the merging of staging into main and release version (needed to release code) and building and testing Bittensor (needed to merge PRs).
-
-> It is highly recommended to set up your own CircleCI pipeline for your subnet.
-
-## Versioning and Release Notes
-
-Semantic versioning helps keep track of the different versions of the software. When code is merged into main, generate a new version.
-
-Release notes provide documentation for each version released to the users, highlighting the new features, improvements, and bug fixes. When merged into main, generate GitHub release and release notes.
-
-## Pending Tasks
-
-Follow these steps when you are contributing to the Bittensor subnet:
-
-- Determine if main and staging are different
-- Determine what is in staging that is not merged yet
- - Document not released developments
- - When merged into staging, generate information about what's merged into staging but not released.
- - When merged into main, generate GitHub release and release notes.
-- CircleCI jobs
- - Merge staging into main and release version (needed to release code)
- - Build and Test Bittensor (needed to merge PRs)
-
-This document can be improved as the Bittensor project continues to develop and change.
diff --git a/contrib/STYLE.md b/contrib/STYLE.md
deleted file mode 100644
index b7ac755f..00000000
--- a/contrib/STYLE.md
+++ /dev/null
@@ -1,348 +0,0 @@
-# Style Guide
-
-A project’s long-term success rests (among other things) on its maintainability, and a maintainer has few tools more powerful than his or her project’s log. It’s worth taking the time to learn how to care for one properly. What may be a hassle at first soon becomes habit, and eventually a source of pride and productivity for all involved.
-
-Most programming languages have well-established conventions as to what constitutes idiomatic style, i.e. naming, formatting and so on. There are variations on these conventions, of course, but most developers agree that picking one and sticking to it is far better than the chaos that ensues when everybody does their own thing.
-
-# Table of Contents
-1. [Code Style](#code-style)
-2. [Naming Conventions](#naming-conventions)
-3. [Git Commit Style](#git-commit-style)
-4. [The Six Rules of a Great Commit](#the-six-rules-of-a-great-commit)
- - [1. Atomic Commits](#1-atomic-commits)
- - [2. Separate Subject from Body with a Blank Line](#2-separate-subject-from-body-with-a-blank-line)
- - [3. Limit the Subject Line to 50 Characters](#3-limit-the-subject-line-to-50-characters)
- - [4. Use the Imperative Mood in the Subject Line](#4-use-the-imperative-mood-in-the-subject-line)
- - [5. Wrap the Body at 72 Characters](#5-wrap-the-body-at-72-characters)
- - [6. Use the Body to Explain What and Why vs. How](#6-use-the-body-to-explain-what-and-why-vs-how)
-5. [Tools Worth Mentioning](#tools-worth-mentioning)
- - [Using `--fixup`](#using---fixup)
- - [Interactive Rebase](#interactive-rebase)
-6. [Pull Request and Squashing Commits Caveats](#pull-request-and-squashing-commits-caveats)
-
-
-### Code style
-
-#### General Style
-Python's official style guide is PEP 8, which provides conventions for writing code for the main Python distribution. Here are some key points:
-
-- `Indentation:` Use 4 spaces per indentation level.
-
-- `Line Length:` Limit all lines to a maximum of 79 characters.
-
-- `Blank Lines:` Surround top-level function and class definitions with two blank lines. Method definitions inside a class are surrounded by a single blank line.
-
-- `Imports:` Imports should usually be on separate lines and should be grouped in the following order:
-
- - Standard library imports.
- - Related third party imports.
- - Local application/library specific imports.
-- `Whitespace:` Avoid extraneous whitespace in the following situations:
-
- - Immediately inside parentheses, brackets or braces.
- - Immediately before a comma, semicolon, or colon.
- - Immediately before the open parenthesis that starts the argument list of a function call.
-- `Comments:` Comments should be complete sentences and should be used to clarify code; they are not an excuse for poorly written code.
-
-#### For Python
-
-- `List Comprehensions:` Use list comprehensions for concise and readable creation of lists.
-
-- `Generators:` Use generators when dealing with large amounts of data to save memory.
-
-- `Context Managers:` Use context managers (with statement) for resource management.
-
-- `String Formatting:` Use f-strings for formatting strings in Python 3.6 and above.
-
-- `Error Handling:` Use exceptions for error handling whenever possible.
-
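-Here is a small, self-contained illustration (hypothetical code, not taken from this repository) that combines several of these idioms:
-
-```python
-def load_labels(path):
-    """Read one label per line from a text file."""
-    try:
-        with open(path) as f:                      # context manager handles closing the file
-            return [line.strip() for line in f]    # list comprehension
-    except OSError as exc:                         # exceptions for error handling
-        raise RuntimeError(f"Could not read labels from {path}") from exc  # f-string
-
-
-def squares(n):
-    """Generator: yields values lazily instead of building a large list in memory."""
-    for i in range(n):
-        yield i * i
-```
-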
-#### More details
-
-Use `black` to format your Python code before committing for consistency across such a large pool of contributors. Black's code [style](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html#code-style) ensures consistent and opinionated code formatting. It automatically formats your Python code according to the Black style guide, enhancing code readability and maintainability.
-
-Key Features of Black:
-
-- `Consistency:` Black enforces a single, consistent coding style across your project, eliminating style debates and allowing developers to focus on code logic.
-
-- `Readability:` By applying a standard formatting style, Black improves code readability, making it easier to understand and collaborate on projects.
-
-- `Automation:` Black automates the code formatting process, saving time and effort. It eliminates the need for manual formatting and reduces the likelihood of inconsistencies.
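-
-As a rough illustration of what Black does (assuming default settings; the exact output may differ slightly between Black versions):
-
-```python
-# Before formatting:
-def scale(values,factor = 2):  return [ v*factor for v in values ]
-
-# After running `black`:
-def scale(values, factor=2):
-    return [v * factor for v in values]
-```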
-
-### Naming Conventions
-
-- `Classes:` Class names should normally use the CapWords Convention.
-- `Functions and Variables:` Function names should be lowercase, with words separated by underscores as necessary to improve readability. Variable names follow the same convention as function names.
-
-- `Constants:` Constants are usually defined on a module level and written in all capital letters with underscores separating words.
-
-- `Non-public Methods and Instance Variables:` Use a single leading underscore (_). This is a weak "internal use" indicator.
-
-- `Strongly "private" methods and variables:` Use a double leading underscore (__). This triggers name mangling in Python.
-
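-A short, purely illustrative example (hypothetical names, not from this codebase) of these conventions in practice:
-
-```python
-MAX_RETRIES = 3                      # constant: ALL_CAPS with underscores
-
-
-class ImageClassifier:               # class name: CapWords
-    def __init__(self, model_path):
-        self.model_path = model_path
-        self._cache = {}             # non-public attribute: single leading underscore
-        self.__signing_key = "demo"  # strongly "private": double underscore (name-mangled)
-
-    def predict_label(self, image):  # function/method name: lowercase_with_underscores
-        return self._lookup(image)
-
-    def _lookup(self, image):        # non-public helper method
-        return self._cache.get(image, "unknown")
-```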
-
-### Git commit style
-
-Here’s a model Git commit message when contributing:
-```
-Summarize changes in around 50 characters or less
-
-More detailed explanatory text, if necessary. Wrap it to about 72
-characters or so. In some contexts, the first line is treated as the
-subject of the commit and the rest of the text as the body. The
-blank line separating the summary from the body is critical (unless
-you omit the body entirely); various tools like `log`, `shortlog`
-and `rebase` can get confused if you run the two together.
-
-Explain the problem that this commit is solving. Focus on why you
-are making this change as opposed to how (the code explains that).
-Are there side effects or other unintuitive consequences of this
-change? Here's the place to explain them.
-
-Further paragraphs come after blank lines.
-
- - Bullet points are okay, too
-
- - Typically a hyphen or asterisk is used for the bullet, preceded
- by a single space, with blank lines in between, but conventions
- vary here
-
-If you use an issue tracker, put references to them at the bottom,
-like this:
-
-Resolves: #123
-See also: #456, #789
-```
-
-
-## The six rules of a great commit.
-
-#### 1. Atomic Commits
-An “atomic” change revolves around one task or one fix.
-
-Atomic Approach
- - Commit each fix or task as a separate change
- - Only commit when a block of work is complete
- - Commit each layout change separately
- - Joint commit for layout file, code behind file, and additional resources
-
-Benefits
-
-- Easy to roll back without affecting other changes
-- Easy to make other changes on the fly
-- Easy to merge features to other branches
-
-#### Avoid trivial commit messages
-
-Commit messages like "fix", "fix2", or "fix3" don't provide any context or clear understanding of what changes the commit introduces. Here are some examples of good vs. bad commit messages:
-
-**Bad Commit Message:**
-
- $ git commit -m "fix"
-
-**Good Commit Message:**
-
- $ git commit -m "Fix typo in README file"
-
-> **Caveat**: When working with new features, an atomic commit will often consist of multiple files, since a layout file, code-behind file, and additional resources may have been added/modified. You don’t want to commit all of these separately, because if you had to roll back the application to a state before the feature was added, it would involve multiple commit entries, and that can get confusing.
-
-#### 2. Separate subject from body with a blank line
-
-Not every commit requires both a subject and a body. Sometimes a single line is fine, especially when the change is so simple that no further context is necessary.
-
-For example:
-
- Fix typo in introduction to user guide
-
-Nothing more need be said; if the reader wonders what the typo was, she can simply take a look at the change itself, i.e. use git show or git diff or git log -p.
-
-If you’re committing something like this at the command line, it’s easy to use the -m option to git commit:
-
- $ git commit -m"Fix typo in introduction to user guide"
-
-However, when a commit merits a bit of explanation and context, you need to write a body. For example:
-
- Derezz the master control program
-
- MCP turned out to be evil and had become intent on world domination.
- This commit throws Tron's disc into MCP (causing its deresolution)
- and turns it back into a chess game.
-
-Commit messages with bodies are not so easy to write with the -m option. You’re better off writing the message in a proper text editor. [See Pro Git](https://git-scm.com/book/en/v2/Customizing-Git-Git-Configuration).
-
-In any case, the separation of subject from body pays off when browsing the log. Here’s the full log entry:
-
- $ git log
- commit 42e769bdf4894310333942ffc5a15151222a87be
- Author: Kevin Flynn
- Date: Fri Jan 01 00:00:00 1982 -0200
-
- Derezz the master control program
-
- MCP turned out to be evil and had become intent on world domination.
- This commit throws Tron's disc into MCP (causing its deresolution)
- and turns it back into a chess game.
-
-
-#### 3. Limit the subject line to 50 characters
-50 characters is not a hard limit, just a rule of thumb. Keeping subject lines at this length ensures that they are readable, and forces the author to think for a moment about the most concise way to explain what’s going on.
-
-GitHub’s UI is fully aware of these conventions. It will warn you if you go past the 50 character limit, and it will truncate any subject line longer than 72 characters with an ellipsis, so keeping the subject to 50 characters or fewer is best practice.
-
-#### 4. Use the imperative mood in the subject line
-Imperative mood just means “spoken or written as if giving a command or instruction”. A few examples:
-
- Clean your room
- Close the door
- Take out the trash
-
-Each of the six rules you’re reading about right now is written in the imperative (“Wrap the body at 72 characters”, etc.).
-
-The imperative can sound a little rude; that’s why we don’t often use it. But it’s perfect for Git commit subject lines. One reason for this is that Git itself uses the imperative whenever it creates a commit on your behalf.
-
-For example, the default message created when using git merge reads:
-
- Merge branch 'myfeature'
-
-And when using git revert:
-
- Revert "Add the thing with the stuff"
-
- This reverts commit cc87791524aedd593cff5a74532befe7ab69ce9d.
-
-Or when clicking the “Merge” button on a GitHub pull request:
-
- Merge pull request #123 from someuser/somebranch
-
-So when you write your commit messages in the imperative, you’re following Git’s own built-in conventions. For example:
-
- Refactor subsystem X for readability
- Update getting started documentation
- Remove deprecated methods
- Release version 1.0.0
-
-Writing this way can be a little awkward at first. We’re more used to speaking in the indicative mood, which is all about reporting facts. That’s why commit messages often end up reading like this:
-
- Fixed bug with Y
- Changing behavior of X
-
-And sometimes commit messages get written as a description of their contents:
-
- More fixes for broken stuff
- Sweet new API methods
-
-To remove any confusion, here’s a simple rule to get it right every time.
-
-**A properly formed Git commit subject line should always be able to complete the following sentence:**
-
-    If applied, this commit will *your subject line here*
-
-For example:
-
- If applied, this commit will refactor subsystem X for readability
- If applied, this commit will update getting started documentation
- If applied, this commit will remove deprecated methods
- If applied, this commit will release version 1.0.0
- If applied, this commit will merge pull request #123 from user/branch
-
-#### 5. Wrap the body at 72 characters
-Git never wraps text automatically. When you write the body of a commit message, you must mind its right margin, and wrap text manually.
-
-The recommendation is to do this at 72 characters, so that Git has plenty of room to indent text while still keeping everything under 80 characters overall.
-
-A good text editor can help here. It’s easy to configure Vim, for example, to wrap text at 72 characters when you’re writing a Git commit.
-
-#### 6. Use the body to explain what and why vs. how
-This [commit](https://github.com/bitcoin/bitcoin/commit/eb0b56b19017ab5c16c745e6da39c53126924ed6) from Bitcoin Core is a great example of explaining what changed and why:
-
-```
-commit eb0b56b19017ab5c16c745e6da39c53126924ed6
-Author: Pieter Wuille
-Date: Fri Aug 1 22:57:55 2014 +0200
-
- Simplify serialize.h's exception handling
-
- Remove the 'state' and 'exceptmask' from serialize.h's stream
- implementations, as well as related methods.
-
- As exceptmask always included 'failbit', and setstate was always
- called with bits = failbit, all it did was immediately raise an
- exception. Get rid of those variables, and replace the setstate
- with direct exception throwing (which also removes some dead
- code).
-
- As a result, good() is never reached after a failure (there are
- only 2 calls, one of which is in tests), and can just be replaced
- by !eof().
-
- fail(), clear(n) and exceptions() are just never called. Delete
- them.
-```
-
-Take a look at the [full diff](https://github.com/bitcoin/bitcoin/commit/eb0b56b19017ab5c16c745e6da39c53126924ed6) and just think how much time the author is saving fellow and future committers by taking the time to provide this context here and now. If he didn’t, it would probably be lost forever.
-
-In most cases, you can leave out details about how a change has been made. Code is generally self-explanatory in this regard (and if the code is so complex that it needs to be explained in prose, that’s what source comments are for). Just focus on making clear the reasons why you made the change in the first place—the way things worked before the change (and what was wrong with that), the way they work now, and why you decided to solve it the way you did.
-
-The future maintainer that thanks you may be yourself!
-
-
-
-#### Tools worth mentioning
-
-##### Using `--fixup`
-
-If you've made a commit and then realize you've missed something or made a minor mistake, you can use the `--fixup` option.
-
-For example, suppose you've made a commit with a hash `9fceb02`. Later, you realize you've left a debug statement in your code. Instead of making a new commit titled "remove debug statement" or "fix", you can do the following:
-
- $ git commit --fixup 9fceb02
-
-This will create a new commit to fix the issue, with a message like "fixup! The original commit message".
-
-##### Interactive Rebase
-
-Interactive rebase, or `rebase -i`, can be used to squash these fixup commits into the original commits they're fixing, which cleans up your commit history. You can use the `autosquash` option to automatically squash any commits marked as "fixup" into their target commits.
-
-For example:
-
- $ git rebase -i --autosquash HEAD~5
-
-This command starts an interactive rebase for the last 5 commits (`HEAD~5`). Any commits marked as "fixup" will be automatically moved to squash with their target commits.
-
-The benefit of using `--fixup` and interactive rebase is that it keeps your commit history clean and readable. It groups fixes with the commits they are related to, rather than having a separate "fix" commit that might not make sense to other developers (or even to you) in the future.
-
-
----
-
-#### Pull Request and Squashing Commits Caveats
-
-While atomic commits are great for development and for understanding the changes within the branch, the commit history can get messy when merging to the main branch. To keep a cleaner and more understandable commit history in our main branch, we encourage squashing all the commits of a PR into one when merging.
-
-This single commit should provide an overview of the changes that the PR introduced. It should follow the guidelines for atomic commits (an atomic commit is complete, self-contained, and understandable) but on the scale of the entire feature, task, or fix that the PR addresses. This approach combines the benefits of atomic commits during development with a clean commit history in our main branch.
-
-Here is how you can squash commits:
-
-```bash
-git rebase -i HEAD~n
-```
-
-where `n` is the number of commits to squash. After running the command, replace `pick` with `squash` for the commits you want to squash into the previous commit. This will combine the commits and allow you to write a new commit message.
-
-In this context, an atomic commit message could look like:
-
-```
-Add feature X
-
-This commit introduces feature X which does A, B, and C. It adds
-new files for layout, updates the code behind the file, and introduces
-new resources. This change is important because it allows users to
-perform task Y more efficiently.
-
-It includes:
-- Creation of new layout file
-- Updates in the code-behind file
-- Addition of new resources
-
-Resolves: #123
-```
-
-In your PRs, remember to detail what the PR is introducing or fixing. This will be helpful for reviewers to understand the context and the reason behind the changes.
diff --git a/create_video_dataset_example.sh b/create_video_dataset_example.sh
deleted file mode 100755
index 02144562..00000000
--- a/create_video_dataset_example.sh
+++ /dev/null
@@ -1,14 +0,0 @@
-#!/bin/bash
-
-# --input_dir is a directory of mp4 files
-# --frames_dir is where the extracted png frames will be stored
-# --dataset_dir is where the huggingface dataset (containing paths to frames) will be stored
-# once your dataset is created, you can add its local path to base_miner/config.py for training
-python base_miner/datasets/create_video_dataset.py --input_dir ~/.cache/sn34/real/video \
- --frames_dir ~/.cache/sn34/train_frames \
- --dataset_dir ~/.cache/sn34/train_dataset/real_frames \
- --num_videos 500 \
- --frame_rate 5 \
- --max_frames 24 \
- --dataset_name real_frames \
- --overwrite
diff --git a/docs/Code_of_Conduct.md b/docs/Code_of_Conduct.md
deleted file mode 100644
index 3a9ef96d..00000000
--- a/docs/Code_of_Conduct.md
+++ /dev/null
@@ -1,50 +0,0 @@
-
-
-
-
-# BitMind Subnet Code of Conduct
-
-
-
-Welcome to the BitMind Subnet Code of Conduct! This document emphasizes our commitment to a safe, respectful, and abuse-free environment. It outlines our expectations for behavior and the consequences of unacceptable actions, including abuse of the subnet’s capabilities and of other community members.
-
-## Purpose
-
-This Code of Conduct aims to protect the BitMind Subnet and its users from abusive behaviors, ensuring a secure and productive environment for all community members regardless of background.
-
-## Expected Behavior
-
-- **Respect for all participants** in speech and actions.
-- **Collaboration before conflict.**
-- **Prompt reporting** of any concerns related to abuse of the technology or community standards.
-
-## Unacceptable Behavior
-
-Unacceptable behavior includes, but is not limited to:
-- **Misuse of the subnet's capabilities** for harmful or illegal activities.
-- **Harassment** in any form, including derogatory comments and unwelcome sexual attention.
-- **Intimidation** or threats.
-- **Disruption** of community events or discussions.
-
-## Consequences of Unacceptable Behavior
-
-Participants asked to stop any harmful activity are expected to comply immediately. If a participant engages in abusive behavior, the BitMind Subnet community organizers may take any actions deemed appropriate, up to and including expulsion from the community and escalation to relevant authorities.
-
-## Reporting Guidelines
-
-If you witness or are subjected to unacceptable behavior, or have any other concerns, please promptly contact a community organizer or report the incident through our [Discord](https://discord.gg/bitmind).
-
-## Scope
-
-This Code of Conduct applies to all community interactions, including forums, social media, and public events related to the BitMind Subnet.
-
-## Contact Info
-
-For any issues, please contact us via [Discord](https://discord.gg/bitmind) or email the project maintainers directly.
-
-## License
-
-This Code of Conduct is licensed under the MIT License, which allows reuse of this document provided attribution is given and this license notice is retained.
-
-Please note that the datasets used by the BitMind Subnet are sourced under their own respective licenses. For more information on dataset licensing and usage guidelines, please refer to the [Datasets 📊 README](Datasets.md) in our repository.
-
diff --git a/docs/Contributor_Guide.md b/docs/Contributor_Guide.md
deleted file mode 100644
index 63aa52ea..00000000
--- a/docs/Contributor_Guide.md
+++ /dev/null
@@ -1,68 +0,0 @@
-# Contributor Guide
-
-## Welcome to the BitMind Subnet Contributor Community!
-
-We're excited to have you interested in contributing to the BitMind Subnet. This guide aims to provide all the information you need to start contributing effectively to our project. Whether you're fixing bugs, adding features, or improving documentation, your help is welcome!
-
-### How to Contribute
-
-#### 1. Get Started
-Before you start contributing, make sure to go through our [Setup Guide 🔧](Setup.md) to get your development environment ready. Also, familiarize yourself with our [Project Structure and Terminology 📖](Glossary.md) to understand the layout and terminology used throughout the project.
-
-#### 2. Find an Issue
-Browse through our [GitHub Issues](https://github.com/bitmind-ai/bitmind-subnet/issues) to find tasks that need help.
-
-#### 3. Fork and Clone the Repository
-- Fork the repository by clicking the "Fork" button on the top right of the page.
-- Clone your fork to your local machine:
-```bash
-git clone https://github.com/your-username/bitmind-subnet.git
-cd bitmind-subnet
-```
-- Set the original repository as your 'upstream' remote:
-```bash
-git remote add upstream https://github.com/bitmind-ai/bitmind-subnet.git
-```
-#### 4. Sync Your Fork.
-Before you start making changes, sync your fork with the upstream repository to ensure you have the latest updates:
-```bash
-git fetch upstream
-git checkout main
-git merge upstream/main
-```
-
-#### 5. Create a Branch
-Create a new branch to work on. It's best to name the branch something descriptive:
-```bash
-git checkout -b feature/add-new-detection-model
-```
-
-#### 6. Make Your Changes
-Make changes to the codebase or documentation. Ensure you follow our coding standards (PEP-8) and write tests if you are adding or modifying functionality.
-
-#### 7. Commit Your Changes
-Keep your commits as small as possible and focused on a single aspect of improvement. This approach makes it easier to review and manage:
-```bash
-git add .
-git commit -m "Add a detailed commit message describing the change"
-```
-
-#### 8. Push Your Changes
-Push your changes to your fork:
-```bash
-git push origin feature/add-new-detection-model
-```
-
-#### 9. Submit a Pull Request (PR)
-Go to the Pull Requests tab in the original repository and click "New pull request". Compare branches and make sure you are proposing changes from your branch to the main repository's main branch. Provide a concise description of the changes and reference any related issues.
-
-#### 10. Participate in the Code Review Process
-Once your PR is submitted, other contributors and maintainers will review your changes. Engage in discussions and make any requested adjustments. Your contributions will be merged once they are approved.
-
-#### Code of Conduct
-We expect all contributors to adhere to our [Code of Conduct 📜](Code_of_Conduct.md), ensuring respect and productive collaboration. Please read the Code of Conduct to understand the expectations for behavior.
-
-#### Need Help?
-If you have any questions or need further guidance, don't hesitate to ask for help in our Discord community. We're here to make your contribution process as smooth as possible!
-
-Thank you for contributing to the BitMind Subnet! We appreciate your effort to help us improve and extend our capabilities in detecting AI-generated media.
\ No newline at end of file
diff --git a/docs/Datasets.md b/docs/Datasets.md
deleted file mode 100644
index a96516b6..00000000
--- a/docs/Datasets.md
+++ /dev/null
@@ -1,40 +0,0 @@
-# Datasets README
-
-This document provides an overview of the datasets used within the BitMind Subnet, including descriptions and links to their sources and access points on Hugging Face. Each dataset is under its own license, and users are encouraged to review these licenses before use.
-
-Details on decentralized data storage access points, which will serve as another way to facilitate the utilization of our datasets, will be provided soon. Stay tuned for updates on how to access these resources efficiently and securely.
-
-## Dataset Categories
-
-### Third-Party Real Image Datasets
-
-These datasets consist of authentic images sourced from various real-world scenarios. They are necessary for training our models to recognize genuine content.
-
-#### Dataset Name
-**Description**: Brief description of what the dataset includes and its purpose.
-**Hugging Face Link**: [Dataset on Hugging Face](#)
-**Original Source**: [Dataset Source](#)
-
-### Third-Party Synthetic Image Datasets
-
-These datasets contain images that are artificially created to simulate different imaging conditions and scenarios, aiding in model training against synthetic manipulations.
-
-#### Dataset Name
-**Description**: Brief description of what the dataset includes and its purpose.
-**Hugging Face Link**: [Dataset on Hugging Face](#)
-**Original Source**: [Dataset Source](#)
-
-### Synthetic Datasets Generated via Image-to-Text Annotation Models
-
-These datasets are created by annotating a third-party real image dataset using an image-to-text annotation model like BLIP-2, followed by generating corresponding images through a diffusion model. This approach ensures that the synthetic data mirrors the distribution of our real image datasets one-to-one, providing a balanced training ground for our models.
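-For intuition, here is a minimal sketch of this annotate-then-generate loop using BLIP-2 (via `transformers`) and Stable Diffusion (via `diffusers`). The checkpoints, file paths, and generation settings are illustrative assumptions, not the subnet's actual pipeline:
-
-```python
-from transformers import Blip2Processor, Blip2ForConditionalGeneration
-from diffusers import StableDiffusionPipeline
-from PIL import Image
-
-# 1. Caption a real image with an image-to-text model (checkpoint choice is illustrative).
-processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-captioner = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
-
-real_image = Image.open("real_example.jpg")  # hypothetical input path
-inputs = processor(images=real_image, return_tensors="pt")
-caption = processor.decode(captioner.generate(**inputs)[0], skip_special_tokens=True)
-
-# 2. Generate a synthetic counterpart from that caption with a diffusion model.
-pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
-synthetic_image = pipe(caption).images[0]
-synthetic_image.save("synthetic_example.png")
-```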
-
-#### Dataset Name
-**Description**: Brief description of how the dataset was created and its purpose.
-**Hugging Face Link**: [Dataset on Hugging Face](#)
-**Original Source**: [Dataset Source](#)
-
-## Usage Guidelines
-
-Please ensure to adhere to the licensing agreements specified for each dataset. These licenses dictate how the datasets can be used, shared, and modified. For detailed licensing information, refer to the respective dataset links provided.
-
-For any further questions or clarifications, please contact the project maintainers or visit our [community Discord](https://discord.gg/bitmind).
diff --git a/docs/Glossary.md b/docs/Glossary.md
deleted file mode 100644
index 2f38afc7..00000000
--- a/docs/Glossary.md
+++ /dev/null
@@ -1,46 +0,0 @@
-# Project Structure and Terminology
-
-## Table of Contents
-
-1. [Overview and Terminology 📖](#overview-and-terminology)
-2. [Notable Directories 📁](#notable-directories)
-3. [Key Files and Descriptions 🗂️](#key-files-and-descriptions)
- - [bitmind/base/ 🔧](#bitmindbase)
- - [bitmind/validator/ 🛡️](#bitmindvalidator)
- - [bitmind/miner/ ⛏️](#bitmindminer)
-4. [Datasets 📊](#datasets)
-5. [Additional Tools 🧰](#additional-tools)
-
-### Overview and Terminology
-
-Before diving into the specifics of the directory structure and key components, let's familiarize ourselves with the essential terms used throughout this project. Understanding these terms is important for navigating and contributing to the BitMind Subnet effectively. For a more detailed explanation of the terminology, please refer to [Bittensor Building Blocks](https://docs.bittensor.com/learn/bittensor-building-blocks).
-
-- **Synapse**: Acts as a communication bridge between axons (servers) and dendrites (clients), facilitating data flow and processing.
-- **Neuron**: A fundamental unit that includes both an axon and a dendrite, enabling full participation in the network operations.
-
-### Notable Directories
-
-- **bitmind/**: This directory contains the specific implementations of Bittensor operations, which include the key components such as miners, validators, and neurons. This code is used both by validators/miners as well as the base_miner training/eval code.
- - **base/**: Houses base classes for miner, validator, and neuron functionalities, each inheriting from the broader Bittensor framework.
-
-### Key Files and Descriptions
-
-#### bitmind/base/
-- **miner.py**: Responsible for loading models and weights, and handling predictions on images.
-- **validator.py**: Implements core functionality for generating challenges for miners, scoring responses, and setting weights.
-- **neuron.py**: A class that inherits from the base miner class provided by Bittensor, incorporating both axon and dendrite functionalities.
-
-#### bitmind/validator/
-- **forward.py**: Manages image processing and synapse operations using `ImageSynapse` for 256x256 images. Includes logic for challenge issuance and reward updates based on performance.
-- **proxy.py**: Temporarily unused; intended for handling frontend requests.
-
-#### bitmind/miner/
-- **predict.py**: Handles image transformation and the execution of model inference.
-
-### Datasets
-
-- **real_fake_dataset**: Utilized by the base miner for training, distinguishing between real and fake images.
-
-### Additional Tools
-
-- **random_image_generator.py**: A class that uses a prompt generation model and a suite of diffusion models to produce synthetic images. Supports caching of image/prompt pairs to a local directory.
\ No newline at end of file
diff --git a/docs/Incentive.md b/docs/Incentive.md
index 25939b34..96bbaa59 100644
--- a/docs/Incentive.md
+++ b/docs/Incentive.md
@@ -12,13 +12,6 @@ This document covers the current state of SN34's incentive mechanism.
Miner rewards are a weighted combination of their performance on video and image detection challenges. Validators keep track of miner performance using a score vector, which is updated using an exponential moving average. These scores are used by validators to set weights for miners, which determine their reward distribution, incentivizing high-quality predictions and consistent performance.
-
-
-
-Simulation applying our latest iteration of our incentive mechanism on historical subnet data. Note that this graphic shows incentive changes at a much more granular timescale (one timestep per challenge) than that of actual weight setting (once per 360 blocks)
incentive-simulator repository
-
-
-
## Rewards
>A miner's total reward $C$ combines their performance across both image and video challenges, weighted by configurable parameters $p$ that control the emphasis placed on each modality.
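
As an illustrative reading (a sketch, not necessarily the exact formula used by validators), with modality weights satisfying $p_{image} + p_{video} = 1$:

$$C = p_{image} \cdot C_{image} + p_{video} \cdot C_{video}$$

where $C_{image}$ and $C_{video}$ are the miner's scores on image and video challenges, respectively.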
diff --git a/docs/Mining.md b/docs/Mining.md
index d4b31f69..18ec7038 100644
--- a/docs/Mining.md
+++ b/docs/Mining.md
@@ -1,20 +1,13 @@
-# Miner Guide
-
-## Table of Contents
-
-1. [Installation 🔧](#installation)
- - [Data 📊](#data)
- - [Registration ✍️](#registration)
-2. [Mining ⛏️](#mining)
-3. [Training 🚂](#training)
+# Miner Setup Guide
## Before you proceed ⚠️
-**IMPORTANT**: If you are new to Bittensor, we recommend familiarizing yourself with the basics on the [Bittensor Website](https://bittensor.com/) before proceeding.
+If you are new to Bittensor, we recommend familiarizing yourself with the basics in the [Bittensor Docs](https://docs.bittensor.com/) before proceeding.
-**Ensure you are running Subtensor locally** to minimize outages and improve performance. See [Run a Subtensor Node Locally](https://github.com/opentensor/subtensor/blob/main/docs/running-subtensor-locally.md#compiling-your-own-binary).
+**Run your own local subtensor** to avoid rate limits set on public endpoints. See [Run a Subtensor Node Locally](https://github.com/opentensor/subtensor/blob/main/docs/running-subtensor-locally.md#compiling-your-own-binary) for setup instructions.
+
+**Understand your minimum compute requirements** for model training and miner deployment, which vary depending on your choice of model. You will likely need at least a consumer-grade GPU for training. Many models can be deployed in CPU-only environments for mining.
-**Be aware of the minimum compute requirements** for our subnet, detailed in [Minimum compute YAML configuration](../min_compute.yml). A GPU is required for training (unless you want to wait weeks for training to complete), but is not required for inference while running a miner.
## Installation
@@ -23,166 +16,86 @@ Download the repository and navigate to the folder.
git clone https://github.com/bitmind-ai/bitmind-subnet.git && cd bitmind-subnet
```
-We recommend using a Conda virtual environment to install the necessary Python packages.
-You can set up Conda with this [quick command-line install](https://docs.anaconda.com/free/miniconda/#quick-command-line-install). Note that after you run the last commands in the miniconda setup process, you'll be prompted to start a new shell session to complete the initialization.
+We recommend using a Conda virtual environment to install the necessary Python packages.
+- You can set up Conda with this [quick command-line install](https://docs.anaconda.com/free/miniconda/#quick-command-line-install).
+- Note that after you run the last commands in the miniconda setup process, you'll be prompted to start a new shell session to complete the initialization.
-With miniconda installed, you can create a virtual environment with this command:
+With miniconda installed, you can create your virtual environment with this command:
```bash
-conda create -y -n bitmind python=3.10 ipython jupyter ipykernel
+conda create -y -n bitmind python=3.10
```
-To activate your virtual environment, run `conda activate bitmind`. To deactivate, `conda deactivate`.
-
-Install the remaining necessary requirements with the following chained command. This may take a few minutes to complete.
+- Activating your virtual environment: `conda activate bitmind`
+- Deactivating your virtual environment: `conda deactivate`
+
+Install the remaining necessary requirements with the following chained command.
```bash
conda activate bitmind
export PIP_NO_CACHE_DIR=1
-chmod +x setup_env.sh
-./setup_env.sh
+chmod +x setup.sh
+./setup.sh
```
-### Data
+Before you register a miner on testnet or mainnet, you must first fill out all the necessary fields in `.env.miner`. Make a copy of the template, and fill in your wallet and axon information.
-*Only for training -- deployed miner instances do not require access to these datasets.*
-
-You can optionally pre-download the training datasets by running:
-
-```bash
-python base_miner/datasets/download_data.py
+```
+cp .env.miner.template .env.miner
```
-Feel free to skip this step - datasets will be downloaded automatically when you run the training scripts.
-The default list of datasets and default download location are defined in `base_miner/config.py`
+## Miner Task
+### Expected Miner Outputs
-## Registration
+> Miners respond to validator queries with a probability vector [$p_{real}$, $p_{synthetic}$, $p_{semisynthetic}$]
-To mine on our subnet, you must have a registered hotkey.
+Your task as an SN34 miner is to classify images and videos as real, synthetic, or semisynthetic.
+- **Real**: Authentic media, not touched in any way by AI
+- **Synthetic**: Fully AI-generated media
+- **Semisynthetic**: AI-modified (spatially, not temporally) media, e.g. faceswaps, inpainting, etc.
-*Note: For testnet tao, you can make requests in the [Bittensor Discord's "Requests for Testnet Tao" channel](https://discord.com/channels/799672011265015819/1190048018184011867)*
+Minor details:
+- You are scored only on correctness, so rounding these probabilities will not give you extra incentive.
+- To maximize incentive, you must respond with the multiclass vector described above (a minimal example is sketched below).
+    - If your classifier returns a binary response (e.g. a float in $[0., 1.]$ or a vector [$p_{real}$, $p_{synthetic}$]), you will earn partial credit (as defined by our incentive mechanism).
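+
+For concreteness, here is a minimal sketch of producing a well-formed prediction (assuming a PyTorch classifier with a 3-class output head, as in the placeholder models in `neurons/miner.py`); the model and preprocessing are stand-ins for your own:
+
+```python
+import torch
+
+def predict(model: torch.nn.Module, media_tensor: torch.Tensor) -> list:
+    """Return [p_real, p_synthetic, p_semisynthetic] for a single image or video."""
+    with torch.no_grad():
+        logits = model(media_tensor.unsqueeze(0).float())  # add a batch dimension
+        probs = torch.softmax(logits, dim=1)[0]            # 3-class probability vector
+    return probs.tolist()
+
+# The miner endpoint wraps this in a JSON response, e.g.
+# {"status": "success", "prediction": [0.85, 0.10, 0.05]}
+```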
-To reduce the risk of deregistration due to technical issues or a poor performing model, we recommend the following:
-1. Test your miner on testnet before you start mining on mainnet.
-2. Before registering your hotkey on mainnet, make sure your port is open by running `curl your_ip:your_port`
-3. If you've trained a custom model, test it's performance by deploying to testnet. You can use this [notebook](https://github.com/BitMind-AI/bitmind-utils/blob/main/wandb_data/wandb_miner_performance.ipynb) to query our tesnet Weights and Biases logs and compute your model's accuracy. Our testnet validator is running 24/7.
+### Training your Detector
-#### Mainnet
+> [!IMPORTANT]
+> The default video and image detection models provided in `neurons/miner.py` serve only to exemplify the desired behavior of the miner neuron, and will not provide competitive performance on mainnet.
-```bash
-btcli s register --netuid 34 --wallet.name [wallet_name] --wallet.hotkey [wallet.hotkey] --subtensor.network finney
-```
+#### Model
-#### Testnet
-
-```bash
-btcli s register --netuid 168 --wallet.name [wallet_name] --wallet.hotkey [wallet.hotkey] --subtensor.network test
-```
+#### Data
-## Mining
-You can launch your validator with `run_neuron.py`.
+## Registration
-First, make sure to update `validator.env` with your **wallet**, **hotkey**, and **miner port**. This file was created for you during setup, and is not tracked by git.
+To run a miner, you must have a registered hotkey.
+> [!IMPORTANT]
+> Registering on a Bittensor subnet burns TAO. To reduce the risk of deregistration due to technical issues or a poor performing model, we recommend the following:
+> 1. Test your miner on testnet before you start mining on mainnet.
+> 2. Before registering your hotkey on mainnet, make sure your axon port is accepting incoming traffic by running `curl your_ip:your_port`
-```bash
-IMAGE_DETECTOR=CAMO # Options: CAMO, UCF, NPR, None
-IMAGE_DETECTOR_CONFIG=camo.yaml # Configs live in base_miner/deepfake_detectors/configs
- # Supply a filename or relative path
-
-VIDEO_DETECTOR=TALL # Options: TALL, None
-VIDEO_DETECTOR_CONFIG=tall.yaml # Configs live in base_miner/deepfake_detectors/configs
- # Supply a filename or relative path
-
-IMAGE_DETECTOR_DEVICE=cpu # Options: cpu, cuda
-VIDEO_DETECTOR_DEVICE=cpu
-
-# Subtensor Network Configuration:
-NETUID=34 # Network User ID options: 34, 168
-SUBTENSOR_NETWORK=finney # Networks: finney, test, local
-SUBTENSOR_CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443
- # Endpoints:
- # - wss://entrypoint-finney.opentensor.ai:443
- # - wss://test.finney.opentensor.ai:443/
-
-# Wallet Configuration:
-WALLET_NAME=default
-WALLET_HOTKEY=default
-
-# Miner Settings:
-MINER_AXON_PORT=8091
-BLACKLIST_FORCE_VALIDATOR_PERMIT=True # Default setting to force validator permit for blacklisting
-```
-Now you're ready to run your miner!
+#### Mainnet
```bash
-conda activate bitmind
-pm2 start run_neuron.py -- --miner
-```
-
-- Auto updates are enabled by default. To disable, run with `--no-auto-updates`.
-- Self-healing restarts are enabled by default (every 6 hours). To disable, run with `--no-self-heal`.
-
-If you want to outperform the base model, you'll need to train on more data or try experiment with different hyperparameters and model architectures. See our [training](#train) section below for more details.
-
-
-## Training
-
-To see performance improvements over the base models, you'll need to train on more data, modify hyperparameters, or try a different modeling strategy altogether. Happy experimenting!
-
-*We are working on a unified interface for training models, but for now, each model has its own training script and logging systems that are functionality independent.*
-
-### NPR
-```python
-cd base_miner/NPR/ && python train_detector.py
-```
-The model with the lowest validation accuracy will be saved to `base_miner/NPR/checkpoints//model_epoch_best.pth`.
-
-### UCF
-```python
-cd base_miner/DFB/ && python train_detector.py --detector [UCF, TALL] --modality [image, video]
+btcli s register --netuid 34 --wallet.name [wallet_name] --wallet.hotkey [wallet.hotkey] --subtensor.network finney
```
-The model with the lowest validation accuracy will be saved to `base_miner/UCF/logs/training//`.
-
-In this directory, you will find your model weights (`ckpt_best.pth`) and training configuration (`config.yaml`). Note that
-the training config, e.g. `config.yaml`, is different from the detector config, e.g. `ucf.yaml`.
-
-
-## Deploy Your Model
-Whether you have trained your own model, designed your own ``DeepfakeDetector`` subclass, or want to deploy a base miner using provided detectors in ``base_miner/deepfake_detectors/``, you can simply update the `miner.env` file to point to the desired detector class and config.
-
-We recommend consulting the `README` in `base_miner/` to learn about the extensibility and modular design of our base miner detectors.
-
-- The detector type (e.g. `UCF`) corresponds to the module name of the ``DeepfakeDetector`` subclass registered in ``base_miner/registry.py``'s ``DETECTOR_REGISTRY``.
-- The associated detector config file (e.g., `ucf.yaml`) lives in `base_miner/deepfake_detectors/configs/`.
- - *For UCF only:* You will need to set the `train_config` field in the detector configuration file (`base_miner/deepfake_detectors/configs/ucf.yaml`) to point to the training configuration file. This allows the instantiation of `UCFDetector` to use the settings from training time to reconstruct the correct model architecture. After training a model, the training config can be found in `base_miner/UCF/logs//config.yaml`. Feel free to move this to a different location, as long as the `train_config` field in `configs/ucf.yaml` reflects this.
-- The model weights file (e.g., `ckpt_best.pth`) should be placed in `base_miner//weights`.
- - If the weights specified in the config file do not exist, the miner will attempt to automatically download them from Hugging Face as specified by the `hf_repo` field in the config file. Feel free to use your own Hugging Face repository for hosting your model weights, and update the config file accordingly.
-
-
-
-## Tensorboard
-
-Training metrics are logged with TensorboardX. You can view interactive graphs of these metrics by starting a tensorboard server with the following command, and navigating to `localhost:6006`.
+#### Testnet
-```bash
-tensorboard --logdir=./base_miner/checkpoints/
-```
+> For testnet tao, you can make requests in the [Bittensor Discord's "Requests for Testnet Tao" channel](https://discord.com/channels/799672011265015819/1190048018184011867)
-If you're using remote compute for training, you can set up port forwarding by ssh'ing onto your machine with the following flags:
```bash
-ssh -L 7007:localhost:6006 your_username@your_ip
+btcli s register --netuid 168 --wallet.name [wallet_name] --wallet.hotkey [wallet.hotkey] --subtensor.network test
```
-with port forwarding enabled, you can start your tensorboard server on your remote machine with the following command, and view the tensorboard UI at `localhost:7007` in your local browser.
+#### Mining
-```bash
-tensorboard --logdir=./base_miner/checkpoints/ --host 0.0.0.0 --port 6006
-```
+You can now launch your miner with `start_miner.sh`, which will use the configuration you provided in `.env.miner` (see the last step of the [Installation](#installation) section).
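+
+For example (assuming `start_miner.sh` manages the process with `pm2`, as the validator setup does):
+
+```bash
+conda activate bitmind
+./start_miner.sh
+pm2 logs   # follow the miner logs to confirm it is serving on your axon port
+```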
diff --git a/docs/Validating.md b/docs/Validating.md
index 26a79d8a..02dcfbfd 100644
--- a/docs/Validating.md
+++ b/docs/Validating.md
@@ -1,17 +1,29 @@
# Validator Guide
-## Table of Contents
+## Before you proceed ⚠️
-1. [Installation 🔧](#installation)
- - [Data 📊](#data)
- - [Registration ✍️](#registration)
-2. [Validating ✅](#validating)
+If you are new to Bittensor (you're probably not if you're reading the validator guide 😎), we recommend familiarizing yourself with the basics in the [Bittensor Docs](https://docs.bittensor.com/) before proceeding.
-## Before you proceed ⚠️
+**Run your own local subtensor** to avoid rate limits set on public endpoints. See [Run a Subtensor Node Locally](https://github.com/opentensor/subtensor/blob/main/docs/running-subtensor-locally.md#compiling-your-own-binary) for setup instructions.
+
+**Understand the minimum compute requirements to run a validator**. Validator neurons on SN34 run a suite of generative (text-to-image, text-to-video, etc.) models that require an **80GB VRAM GPU**. They also maintain a large cache of real and synthetic media to ensure diverse, locally available data for challenging miners. We recommend **1TB of storage**. For more details, please see our [minimum compute documentation](../min_compute.yml).
+
+## Required Hugging Face Model Access
-**Ensure you are running Subtensor locally** to minimize outages and improve performance. See [Run a Subtensor Node Locally](https://github.com/opentensor/subtensor/blob/main/docs/running-subtensor-locally.md#compiling-your-own-binary).
+To properly validate, you must gain access to several Hugging Face models used by the subnet. This requires logging in to your Hugging Face account and accepting the terms for each model below:
-**Be aware of the minimum compute requirements** for our subnet, detailed in [Minimum compute YAML configuration](../min_compute.yml).
+- [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
+- [DeepFloyd IF-II-L-v1.0](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)
+- [DeepFloyd IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0)
+
+> **Note:** Accepting the terms for any one of the DeepFloyd IF models (e.g., IF-II-L or IF-I-XL) will grant you access to all DeepFloyd IF models.
+>
+> **If you've been validating with us for a while (prior to V3), you've likely already gotten access to these models and can disregard this step.**
+
+To do this:
+1. Log in to your Hugging Face account.
+2. Visit each model page above.
+3. Click the "Access repository" or "Agree and access repository" button to accept the terms.
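+
+Once access is granted, make sure your validator can authenticate to Hugging Face. A minimal sketch (assuming the standard `huggingface_hub` CLI is installed):
+
+```bash
+# Log in with a read-access token from https://huggingface.co/settings/tokens
+huggingface-cli login
+```
+
+Alternatively, set `HUGGING_FACE_TOKEN` in your `.env.validator` (see the Installation section below).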
## Installation
@@ -20,22 +32,31 @@ Download the repository and navigate to the folder.
git clone https://github.com/bitmind-ai/bitmind-subnet.git && cd bitmind-subnet
```
-We recommend using a Conda virtual environment to install the necessary Python packages.
-You can set up Conda with this [quick command-line install](https://docs.anaconda.com/free/miniconda/#quick-command-line-install), and create a virtual environment with this command:
+We recommend using a Conda virtual environment to install the necessary Python packages.
+- You can set up Conda with this [quick command-line install](https://www.anaconda.com/docs/getting-started/miniconda/install#linux).
+- Note that after you run the last commands in the miniconda setup process, you'll be prompted to start a new shell session to complete the initialization.
+
+With miniconda installed, you can create your virtual environment with this command:
```bash
conda create -y -n bitmind python=3.10
```
-To activate your virtual environment, run `conda activate bitmind`. To deactivate, `conda deactivate`.
-
-Install the remaining necessary requirements with the following chained command.
+- Activating your virtual environment: `conda activate bitmind`
+- Deactivating your virtual environment: `conda deactivate`
+
+Install the remaining necessary requirements with the following chained command.
```bash
conda activate bitmind
export PIP_NO_CACHE_DIR=1
-chmod +x setup_env.sh
-./setup_env.sh
+chmod +x setup.sh
+./setup.sh
+```
+
+Before you register, you should first fill out all the necessary fields in `.env.validator`. Make a copy of the template, and fill in your wallet information.
+
+```
+cp .env.validator.template .env.validator
```
## Registration
@@ -57,31 +78,11 @@ btcli s register --netuid 168 --wallet.name [wallet_name] --wallet.hotkey [walle
## Validating
-You can launch your validator with `run_neuron.py`.
+Before starting your validator, please ensure you've populated the empty fields in `.env.validator`, including `WANDB_API_KEY` and `HUGGING_FACE_TOKEN`.
-First, make sure to update `validator.env` with your **wallet**, **hotkey**, and **validator port**. This file was created for you during setup, and is not tracked by git.
-
-```bash
-NETUID=34 # Network User ID options: 34, 168
-SUBTENSOR_NETWORK=finney # Networks: finney, test, local
-SUBTENSOR_CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443
- # Endpoints:
- # - wss://entrypoint-finney.opentensor.ai:443
- # - wss://test.finney.opentensor.ai:443/
-
-# Wallet Configuration:
-WALLET_NAME=default
-WALLET_HOTKEY=default
-
-# Note: If you're using RunPod, you must select a port >= 70000 for symmetric mapping
-# Validator Port Setting:
-VALIDATOR_AXON_PORT=8092
-VALIDATOR_PROXY_PORT=10913
-DEVICE=cuda
-
-# API Keys:
-WANDB_API_KEY=your_wandb_api_key_here
-HUGGING_FACE_TOKEN=your_hugging_face_token_here
+If you haven't already, you can start by copying the template:
+```
+cp .env.validator.template .env.validator
```
If you don't have a W&B API key, please reach out to the BitMind team via Discord and we can provide one.
@@ -90,24 +91,23 @@ Now you're ready to run your validator!
```bash
conda activate bitmind
-pm2 start run_neuron.py -- --validator
+./start_validator.sh
```
+
- Auto updates are enabled by default. To disable, run with `--no-auto-updates`.
- Self-healing restarts are enabled by default (every 6 hours). To disable, run with `--no-self-heal`.
-The above command will kick off 4 `pm2` processes
+The above command will kick off 3 `pm2` processes
```
-┌────┬───────────────────────────┬─────────────┬─────────┬─────────┬──────────┬────────┬──────┬───────────┬──────────┬──────────┬──────────┬──────────┐
-│ id │ name │ namespace │ version │ mode │ pid │ uptime │ ↺ │ status │ cpu │ mem │ user │ watching │
-├────┼───────────────────────────┼─────────────┼─────────┼─────────┼──────────┼────────┼──────┼───────────┼──────────┼──────────┼──────────┼──────────┤
-│ 2 │ bitmind_cache_updater │ default │ N/A │ fork │ 1601308 │ 2h │ 0 │ online │ 0% │ 843.6mb │ user │ disabled │
-│ 3 │ bitmind_data_generator │ default │ N/A │ fork │ 1601426 │ 2h │ 0 │ online │ 0% │ 11.3gb │ user │ disabled │
-│ 1 │ bitmind_validator │ default │ N/A │ fork │ 1601246 │ 2h │ 0 │ online │ 0% │ 867.8mb │ user │ disabled │
-│ 0 │ run_neuron │ default │ N/A │ fork │ 223218 │ 41h │ 0 │ online │ 0% │ 8.9mb │ user │ disabled │
-└────┴───────────────────────────┴─────────────┴─────────┴─────────┴──────────┴────────┴──────┴───────────┴──────────┴──────────┴──────────┴──────────┘
+┌────┬───────────────────┬─────────────┬─────────┬─────────┬──────────┬────────┬──────┬───────────┬──────────┬──────────┬──────────┬──────────┐
+│ id │ name │ namespace │ version │ mode │ pid │ uptime │ ↺ │ status │ cpu │ mem │ user │ watching │
+├────┼───────────────────┼─────────────┼─────────┼─────────┼──────────┼────────┼──────┼───────────┼──────────┼──────────┼──────────┼──────────┤
+│ 0 │ sn34-generator │ default │ N/A │ fork │ 2397505 │ 38m │ 2 │ online │ 100% │ 3.0gb │ user │ disabled │
+│ 2 │ sn34-proxy │ default │ N/A │ fork │ 2398000 │ 27m │ 1 │ online │ 0% │ 695.2mb │ user │ disabled │
+│ 1 │ sn34-validator │ default │ N/A │ fork │ 2394939 │ 108m │ 0 │ online │ 0% │ 5.8gb │ user │ disabled │
+└────┴───────────────────┴─────────────┴─────────┴─────────┴──────────┴────────┴──────┴───────────┴──────────┴──────────┴──────────┴──────────┘
```
-- `run_neuron` manages self heals and auto updates
-- `bitmind_validator` is the validator process, whose hotkey, port, etc. are configured in `validator.env`
-- `bitmind_data_generator` runs our data generation pipeline to produce **synthetic images and videos** (stored in `~/.cache/sn34/synthetic`)
-- `bitmind_cache_updater` manages the cache of **real images and videos** (stored in `~/.cache/sn34/real`)
+- `sn34-validator` is the validator process
+- `sn34-generator` runs our data generation pipeline to produce **synthetic images and videos** (stored in `~/.cache/sn34`)
+- `sn34-proxy` routes organic traffic from our applications to miners.
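+
+To monitor these processes, the standard `pm2` commands apply, for example:
+
+```bash
+pm2 list                    # show all three processes and their status
+pm2 logs sn34-validator     # follow validator logs
+pm2 restart sn34-generator  # restart the generation pipeline if needed
+```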
diff --git a/static/Bitmind-Logo.png b/docs/static/Bitmind-Logo.png
similarity index 100%
rename from static/Bitmind-Logo.png
rename to docs/static/Bitmind-Logo.png
diff --git a/static/Join-BitMind-Discord.png b/docs/static/Join-BitMind-Discord.png
similarity index 100%
rename from static/Join-BitMind-Discord.png
rename to docs/static/Join-BitMind-Discord.png
diff --git a/static/Subnet-Arch.png b/docs/static/Subnet-Arch.png
similarity index 100%
rename from static/Subnet-Arch.png
rename to docs/static/Subnet-Arch.png
diff --git a/min_compute.yml b/min_compute.yml
index 5254a94d..845f5c50 100644
--- a/min_compute.yml
+++ b/min_compute.yml
@@ -1,51 +1,12 @@
-# Use this document to specify the minimum compute requirements.
-# This document will be used to generate a list of recommended hardware for your subnet.
+# NOTE FOR MINERS:
+# Miner min compute varies based on selected model architecture.
+# For model training, you will most likely need a GPU. For miner deployment, depending
+# on your model, you may be able to get away with CPU.
-# This is intended to give a rough estimate of the minimum requirements
-# so that the user can make an informed decision about whether or not
-# they want to run a miner or validator on their machine.
-
-# NOTE: Specification for miners may be different from validators
-
-version: '1.1' # update this version key as needed, ideally should match your release version
+version: '3.0.0'
compute_spec:
- miner:
-
- cpu:
- min_cores: 2 # Minimum number of CPU cores
- min_speed: 2.5 # Minimum speed per core (GHz)
- recommended_cores: 4 # Recommended number of CPU cores
- recommended_speed: 3.5 # Recommended speed per core (GHz)
- architecture: "x86_64" # Architecture type (e.g., x86_64, arm64)
-
- gpu:
- required: True # Does the application require a GPU?
- min_vram: 8 # Minimum GPU VRAM (GB)
- recommended_vram: 8 # Recommended GPU VRAM (GB)
- cuda_cores: 1920 # Minimum number of CUDA cores (if applicable)
- min_compute_capability: 6.1 # Minimum CUDA compute capability
- recommended_compute_capability: 6.1 # Recommended CUDA compute capability
- recommended_gpu: "NVIDIA GTX 1070" # provide a recommended GPU to purchase/rent
-
- memory:
- min_ram: 8 # Minimum RAM (GB)
- min_swap: 4 # Minimum swap space (GB)
- recommended_swap: 8 # Recommended swap space (GB)
- ram_type: "DDR4" # RAM type (e.g., DDR4, DDR3, etc.)
-
- storage:
- min_space: 100 # Minimum free storage space (GB)
- recommended_space: 200 # Recommended free storage space (GB)
- type: "SSD" # Preferred storage type (e.g., SSD, HDD)
- min_iops: 1000 # Minimum I/O operations per second (if applicable)
- recommended_iops: 5000 # Recommended I/O operations per second
-
- os:
- name: "Ubuntu" # Name of the preferred operating system(s)
- version: 22.04 # Version of the preferred operating system(s)
-
validator:
cpu:
@@ -82,7 +43,7 @@ compute_spec:
ram_type: "DDR6" # RAM type (e.g., DDR4, DDR3, etc.)
storage:
- min_space: 600 # Minimum free storage space (GB)
+ min_space: 1000 # Minimum free storage space (GB)
recommended_space: 1000 # Recommended free storage space (GB)
type: "SSD" # Preferred storage type (e.g., SSD, HDD)
min_iops: 1000 # Minimum I/O operations per second (if applicable)
diff --git a/neurons/README.md b/neurons/README.md
deleted file mode 100644
index fd76958c..00000000
--- a/neurons/README.md
+++ /dev/null
@@ -1,8 +0,0 @@
-## CAMO Base Miner
-
-In version 1.1.0, our base miner incorporates the Content-Aware Model Orchestration (CAMO) framework to enhance deepfake detection.
-This framework utilizes a mixture-of-experts system, combining both generalist models and specialized expert models to improve the accuracy of deepfake identification.
-
-Read more here:
-
-https://bitmindlabs.notion.site/CAMO-Content-Aware-Model-Orchestration-CAMO-Framework-for-Deepfake-Detection-43ef46a0f9de403abec7a577a45cd075
diff --git a/neurons/base.py b/neurons/base.py
new file mode 100644
index 00000000..1450d081
--- /dev/null
+++ b/neurons/base.py
@@ -0,0 +1,147 @@
+import argparse
+from threading import Thread
+from typing import Callable, List
+import bittensor as bt
+import copy
+import inspect
+import traceback
+
+from bittensor.core.settings import SS58_FORMAT, TYPE_REGISTRY
+from nest_asyncio import asyncio
+from substrateinterface import SubstrateInterface
+import signal
+
+from bitmind import (
+ __spec_version__ as spec_version,
+)
+from bitmind.metagraph import run_block_callback_thread
+from bitmind.types import NeuronType
+from bitmind.utils import ExitContext, on_block_interval
+from bitmind.config import (
+ add_args,
+ add_validator_args,
+ add_miner_args,
+ add_proxy_args,
+ validate_config_and_neuron_path,
+)
+
+
+class BaseNeuron:
+ config: "bt.config"
+ neuron_type: NeuronType
+ exit_context = ExitContext()
+ next_sync_block = None
+ block_callbacks: List[Callable] = []
+ substrate_thread: Thread
+
+ def check_registered(self):
+ if not self.subtensor.is_hotkey_registered(
+ netuid=self.config.netuid,
+ hotkey_ss58=self.wallet.hotkey.ss58_address,
+ ):
+ bt.logging.error(
+ f"Wallet: {self.wallet} is not registered on netuid {self.config.netuid}."
+ f" Please register the hotkey using `btcli subnets register` before trying again"
+ )
+ exit()
+
+ @on_block_interval("epoch_length")
+ async def maybe_sync_metagraph(self, block):
+ self.check_registered()
+ bt.logging.info("Resyncing Metagraph")
+ self.metagraph.sync(subtensor=self.subtensor)
+
+ if self.neuron_type == NeuronType.VALIDATOR:
+ bt.logging.info("Metagraph updated, re-syncing hotkeys and moving averages")
+ self.eval_engine.sync_to_metagraph()
+
+ async def run_callbacks(self, block):
+ if (
+ hasattr(self, "initialization_complete")
+ and not self.initialization_complete
+ ):
+ bt.logging.debug(
+ f"Skipping callbacks at block {block} during initialization"
+ )
+ return
+
+ for callback in self.block_callbacks:
+ try:
+ res = callback(block)
+ if inspect.isawaitable(res):
+ await res
+ except Exception as e:
+ bt.logging.error(
+ f"Failed running callback {callback.__name__}: {str(e)}"
+ )
+ bt.logging.error(traceback.format_exc())
+
+ def __init__(self, config=None):
+ bt.logging.info(
+ f"Bittensor Version: {bt.__version__} | SN34 Version {spec_version}"
+ )
+
+ parser = argparse.ArgumentParser()
+ bt.wallet.add_args(parser)
+ bt.subtensor.add_args(parser)
+ bt.logging.add_args(parser)
+ add_args(parser)
+
+ if self.neuron_type == NeuronType.VALIDATOR:
+ bt.axon.add_args(parser)
+ add_validator_args(parser)
+ if self.neuron_type == NeuronType.VALIDATOR_PROXY:
+ add_validator_args(parser)
+ add_proxy_args(parser)
+ if self.neuron_type == NeuronType.MINER:
+ bt.axon.add_args(parser)
+ add_miner_args(parser)
+
+ self.config = bt.config(parser)
+ if config:
+ base_config = copy.deepcopy(config)
+ self.config.merge(base_config)
+
+ validate_config_and_neuron_path(self.config)
+
+ ## Add kill signals
+ signal.signal(signal.SIGINT, self.exit_context.startExit)
+ signal.signal(signal.SIGTERM, self.exit_context.startExit)
+
+ ## LOGGING
+ bt.logging(config=self.config, logging_dir=self.config.neuron.full_path)
+ bt.logging.set_info()
+ if self.config.logging.debug:
+ bt.logging.set_debug(True)
+ if self.config.logging.trace:
+ bt.logging.set_trace(True)
+
+ ## BITTENSOR INITIALIZATION
+ bt.logging.success(self.config)
+ self.wallet = bt.wallet(config=self.config)
+ self.subtensor = bt.subtensor(
+ config=self.config, network=self.config.subtensor.chain_endpoint
+ )
+ self.metagraph = self.subtensor.metagraph(self.config.netuid)
+
+ self.loop = asyncio.get_event_loop()
+ bt.logging.debug(f"Wallet: {self.wallet}")
+ bt.logging.debug(f"Subtensor: {self.subtensor}")
+ bt.logging.debug(f"Metagraph: {self.metagraph}")
+
+    ## CHECK IF REGISTERED
+ self.check_registered()
+ self.uid = self.metagraph.hotkeys.index(self.wallet.hotkey.ss58_address)
+
+ ## Substrate, Subtensor and Metagraph
+ self.substrate = SubstrateInterface(
+ ss58_format=SS58_FORMAT,
+ use_remote_preset=True,
+ url=self.config.subtensor.chain_endpoint,
+ type_registry=TYPE_REGISTRY,
+ )
+
+ self.block_callbacks.append(self.maybe_sync_metagraph)
+ self.substrate_thread = run_block_callback_thread(
+ self.substrate, self.run_callbacks
+ )
diff --git a/neurons/generator.py b/neurons/generator.py
new file mode 100644
index 00000000..fe2f1e9a
--- /dev/null
+++ b/neurons/generator.py
@@ -0,0 +1,264 @@
+import asyncio
+import sys
+import io
+import time
+import signal
+import traceback
+import argparse
+from pathlib import Path
+from PIL import Image
+from typing import List, Dict, Any
+import os
+import atexit
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
+os.environ["TRANSFORMERS_VERBOSITY"] = "error"
+os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
+os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
+
+import warnings
+
+for module in ["diffusers", "transformers.tokenization_utils_base"]:
+ warnings.filterwarnings("ignore", category=FutureWarning, module=module)
+
+import logging
+
+logging.getLogger("transformers").setLevel(logging.ERROR)
+logging.getLogger("diffusers").setLevel(logging.ERROR)
+logging.getLogger("torch").setLevel(logging.ERROR)
+logging.getLogger("datasets").setLevel(logging.ERROR)
+
+import transformers
+
+transformers.logging.set_verbosity_error()
+
+import bittensor as bt
+from bitmind.config import add_args, add_data_generator_args
+from bitmind.utils import ExitContext, get_metadata
+from bitmind.wandb_utils import init_wandb
+from bitmind.types import CacheConfig, MediaType, Modality
+from bitmind.cache.sampler import ImageSampler
+from bitmind.generation import (
+ GenerationPipeline,
+ initialize_model_registry,
+)
+
+
+class Generator:
+ def __init__(self):
+ self.exit_context = ExitContext()
+ self.task = None
+ self.generation_pipeline = None
+ self.image_sampler = None
+
+ self.setup_signal_handlers()
+ atexit.register(self.cleanup)
+
+ parser = argparse.ArgumentParser()
+ bt.subtensor.add_args(parser)
+ bt.wallet.add_args(parser)
+ bt.logging.add_args(parser)
+ add_data_generator_args(parser)
+ add_args(parser)
+
+ self.config = bt.config(parser)
+
+ bt.logging(config=self.config, logging_dir=self.config.neuron.full_path)
+ bt.logging.set_trace()
+ if self.config.logging.debug:
+ bt.logging.set_debug(True)
+ if self.config.logging.trace:
+ bt.logging.set_trace(True)
+
+ bt.logging.success(self.config)
+ wallet_configured = (
+ self.config.wallet.name is not None
+ and self.config.wallet.hotkey is not None
+ )
+ if wallet_configured and not self.config.wandb_off:
+ try:
+ self.wallet = bt.wallet(config=self.config)
+ self.uid = (
+ bt.subtensor(
+ config=self.config, network=self.config.subtensor.chain_endpoint
+ )
+ .metagraph(self.config.netuid)
+ .hotkeys.index(self.wallet.hotkey.ss58_address)
+ )
+ self.wandb_run = init_wandb(
+ self.config.copy(),
+ self.config.wandb.process_name,
+ self.uid,
+ self.wallet.hotkey,
+ )
+
+ except Exception as e:
+ bt.logging.error("Not registered, can't sign W&B run")
+ bt.logging.error(e)
+ self.config.wandb.off = True
+
+ def setup_signal_handlers(self):
+ signal.signal(signal.SIGINT, self.signal_handler)
+ signal.signal(signal.SIGTERM, self.signal_handler)
+ signal.signal(signal.SIGQUIT, self.signal_handler)
+
+ def signal_handler(self, sig, frame):
+ signal_name = signal.Signals(sig).name
+ bt.logging.info(f"Received {signal_name}, initiating shutdown...")
+ self.cleanup()
+ sys.exit(0)
+
+ def cleanup(self):
+ if self.task and not self.task.done():
+ self.task.cancel()
+
+ if self.generation_pipeline:
+ try:
+ bt.logging.trace("Shutting down generator...")
+ self.generation_pipeline.shutdown()
+ bt.logging.success("Generator shut down gracefully")
+ except Exception as e:
+ bt.logging.error(f"Error during generator shutdown: {e}")
+
+ # Force cleanup of any GPU memory
+ try:
+ import torch
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ bt.logging.trace("CUDA memory cache cleared")
+ except Exception as e:
+ pass
+
+ async def wait_for_cache(self, timeout: int = 300):
+ """Wait for the cache to be populated with images for prompt generation"""
+ start = time.time()
+ attempts = 0
+ while True:
+ if time.time() - start > timeout:
+ return False
+
+ available_count = self.image_sampler.get_available_count(use_index=False)
+ if available_count > 0:
+ return True
+
+ await asyncio.sleep(10)
+ if not attempts % 3:
+ bt.logging.info("Waiting for images in cache...")
+ attempts += 1
+
+ async def sample_images(self, k: int = 1) -> List[Dict[str, Any]]:
+ """Sample images from the cache"""
+ result = await self.image_sampler.sample(k, remove_from_cache=False)
+ if result["count"] == 0:
+ raise ValueError("No images available in cache")
+
+ # Convert bytes to PIL images
+ for item in result["items"]:
+ if isinstance(item["image"], bytes):
+ item["image"] = Image.open(io.BytesIO(item["image"]))
+
+ return result["items"]
+
+ async def run(self):
+ """Main generator loop"""
+ try:
+ cache_dir = self.config.cache_dir
+ batch_size = self.config.batch_size
+ device = self.config.device
+
+ Path(cache_dir).mkdir(parents=True, exist_ok=True)
+
+ self.image_sampler = ImageSampler(
+ CacheConfig(
+ modality=Modality.IMAGE.value,
+ media_type=MediaType.REAL.value,
+ base_dir=Path(cache_dir),
+ )
+ )
+
+ await self.wait_for_cache()
+ bt.logging.success("Cache populated. Proceeding to generation.")
+
+ model_registry = initialize_model_registry()
+ model_names = model_registry.get_interleaved_model_names(self.config.tasks)
+ bt.logging.info(f"Starting generator")
+ bt.logging.info(f"Tasks: {self.config.tasks}")
+ bt.logging.info(f"Models: {model_names}")
+
+ self.generation_pipeline = GenerationPipeline(
+ output_dir=cache_dir,
+ device=device,
+ )
+
+ gen_count = 0
+ batch_count = 0
+ while not self.exit_context.isExiting:
+ if asyncio.current_task().cancelled():
+ break
+
+ try:
+ image_samples = await self.sample_images(batch_size)
+ bt.logging.info(
+                        f"Starting batch generation | Batch Size: {len(image_samples)} | Batch Count: {batch_count}"
+ )
+
+ start_time = time.time()
+
+ filepaths = self.generation_pipeline.generate(
+ image_samples, model_names=model_names
+ )
+ await asyncio.sleep(1)
+
+ duration = time.time() - start_time
+ gen_count += len(filepaths)
+ batch_count += 1
+ bt.logging.info(
+ f"Generated {len(filepaths)} files in batch #{batch_count} in {duration:.2f} seconds"
+ )
+
+ if not self.config.wandb.off:
+ if batch_count >= self.config.wandb.num_batches_per_run:
+ batch_count = 0
+ self.wandb_run.finish()
+ self.wandb_run = init_wandb(
+ self.config.copy(),
+ self.config.wandb.process_name,
+ self.uid,
+ self.wallet.hotkey,
+ )
+
+ except asyncio.CancelledError:
+ bt.logging.info("Task cancelled, exiting loop")
+ break
+ except Exception as e:
+ bt.logging.error(f"Error in batch processing: {e}")
+ bt.logging.error(traceback.format_exc())
+ await asyncio.sleep(10)
+ except Exception as e:
+ bt.logging.error(f"Unhandled exception in main task: {e}")
+ bt.logging.error(traceback.format_exc())
+ raise
+ finally:
+ self.cleanup()
+
+ def start(self):
+ """Start the generator"""
+ loop = asyncio.get_event_loop()
+ try:
+ self.task = asyncio.ensure_future(self.run())
+ loop.run_until_complete(self.task)
+ except KeyboardInterrupt:
+ bt.logging.info("Generator interrupted by KeyboardInterrupt, shutting down")
+ except Exception as e:
+ bt.logging.error(f"Unhandled exception: {e}")
+ bt.logging.error(traceback.format_exc())
+ finally:
+ self.cleanup()
+
+
+if __name__ == "__main__":
+ generator = Generator()
+ generator.start()
+ sys.exit(0)
diff --git a/neurons/miner.py b/neurons/miner.py
index e894bb14..455fb618 100644
--- a/neurons/miner.py
+++ b/neurons/miner.py
@@ -1,180 +1,414 @@
-# The MIT License (MIT)
-# Copyright © 2023 Yuma Rao
-# developer: aliang322, benliang99, dubm
-# Copyright © 2023 Bitmind
+import io
+import os
+import time
+import traceback
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+import av
+import json
+import bittensor as bt
+import numpy as np
+import requests
+import tempfile
+import torch
+import uvicorn
+from bittensor.core.axon import FastAPIThreadedServer
+from bittensor.core.extrinsics.serving import serve_extrinsic
+from bittensor.core.settings import SS58_FORMAT, TYPE_REGISTRY
+from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
+from PIL import Image
+from torchvision import models
+from substrateinterface import SubstrateInterface
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
+from bitmind.epistula import verify_signature, EPISTULA_VERSION
+from bitmind.metagraph import (
+ run_block_callback_thread,
+)
+from bitmind.types import NeuronType
+from bitmind.utils import print_info
+from neurons.base import BaseNeuron
-# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-from PIL import Image
-import bittensor as bt
-import torch
-import base64
-import time
-import typing
-import io
-import os
-import sys
-import numpy as np
+class Detector:
+ def __init__(self, config):
+ self.config = config
+ self.image_detector = None
+ self.video_detector = None
+ self.device = (
+ self.config.device
+ if hasattr(self.config, "device")
+ else "cuda" if torch.cuda.is_available() else "cpu"
+ )
+ self.load_model()
+
+ def load_model(self, modality=None):
+ """Load the appropriate detection model based on modality.
+
+ MINER TODO:
+ This class has placeholder models to demonstrate the required outputs
+ for validator requests. They have not been trained and will perform
+ poorly. Your task is to train performant video and image detection
+ models and load them here. Happy mining!
+
+ Args:
+ modality (str): Type of detection model to load ('image' or 'video')
+ """
+ bt.logging.info(f"Loading {modality} detection model...")
+ if modality in ("image", None):
+ ### REPLACE WITH YOUR OWN MODEL
+ self.image_detector = models.resnet50(pretrained=True)
+ num_ftrs = self.image_detector.fc.in_features
+ self.image_detector.fc = torch.nn.Linear(num_ftrs, 3)
+ self.image_detector = self.image_detector.to(self.device)
+ self.image_detector.eval()
+
+ if modality in ("video", None):
+ ### REPLACE WITH YOUR OWN MODEL
+ self.video_detector = models.video.r3d_18(pretrained=True)
+ num_ftrs = self.video_detector.fc.in_features
+ self.video_detector.fc = torch.nn.Linear(num_ftrs, 3)
+ self.video_detector = self.video_detector.to(self.device)
+ self.video_detector.eval()
+
+ else:
+ raise ValueError(f"Unsupported modality: {modality}")
+
+ def preprocess(self, media_tensor, modality):
+ bt.logging.debug(
+ json.dumps(
+ {
+                    "modality": modality,
+ "shape": tuple(media_tensor.shape),
+ "dtype": str(media_tensor.dtype),
+ "min": torch.min(media_tensor).item(),
+ "max": torch.max(media_tensor).item(),
+ },
+ indent=2,
+ )
+ )
+
+ if modality == "image":
+ media_tensor = media_tensor.unsqueeze(0).float().to(self.device)
+ elif modality == "video":
+ media_tensor = media_tensor.unsqueeze(0).float().to(self.device)
+ return media_tensor
+
+ def detect(self, media_tensor, modality):
+ """Perform inference with either self.video_detector or self.image_detector
+
+ MINER TODO: Update detection logic as necessary for your own model
+
+ Args:
+ tensor (torch.tensor): Input media tensor
+ modality (str): Type of detection to perform ('image' or 'video')
+
+ Returns:
+ torch.Tensor: Probability vector containing 3 class probabilities
+ [p_real, p_synthetic, p_semisynthetic]
+ """
+ media_tensor = self.preprocess(media_tensor, modality)
+
+ if modality == "image":
+ if self.image_detector is None:
+ self.load_model("image")
+
+ bt.logging.debug(
+ f"Running image detection on array shape {media_tensor.shape}"
+ )
+
+ # MINER TODO update detection logic as necessary
+ with torch.no_grad():
+ outputs = self.image_detector(media_tensor)
+ probs = torch.softmax(outputs, dim=1)[0]
+
+ elif modality == "video":
+ if self.video_detector is None:
+ self.load_model("video")
+
+ bt.logging.debug(
+ f"Running video detection on array shape {media_tensor.shape}"
+ )
+
+ # MINER TODO update detection logic as necessary
+ with torch.no_grad():
+ outputs = self.video_detector(media_tensor)
+ probs = torch.softmax(outputs, dim=1)[0]
-from base_miner.registry import DETECTOR_REGISTRY
-from base_miner.deepfake_detectors import NPRImageDetector, UCFImageDetector, CAMOImageDetector, TALLVideoDetector
-from bitmind.base.miner import BaseMinerNeuron
-from bitmind.protocol import ImageSynapse, VideoSynapse, decode_video_synapse
-from bitmind.utils.config import get_device
+ else:
+ raise ValueError(f"Unsupported modality: {modality}")
+
+ bt.logging.success(f"Prediction: {probs}")
+ return probs
-class Miner(BaseMinerNeuron):
+class Miner(BaseNeuron):
+ neuron_type = NeuronType.MINER
+ fast_api: FastAPIThreadedServer
+ initialization_complete: bool = False
def __init__(self, config=None):
- super(Miner, self).__init__(config=config)
- bt.logging.info("Attaching forward function to miner axon.")
- self.axon.attach(
- forward_fn=self.forward_image,
- blacklist_fn=self.blacklist_image,
- priority_fn=self.priority_image,
- ).attach(
- forward_fn=self.forward_video,
- blacklist_fn=self.blacklist_video,
- priority_fn=self.priority_video,
+ super().__init__(config)
+ bt.logging.set_info()
+ ## Typesafety
+ assert self.config.netuid
+ assert self.config.logging
+
+ self.detector = Detector(self.config)
+
+ # Register log callback
+ self.block_callbacks.append(self.log_on_block)
+
+ ## BITTENSOR INITIALIZATION
+ bt.logging.info(
+ "\N{GRINNING FACE WITH SMILING EYES}", "Successfully Initialized!"
)
- bt.logging.info(f"Axon created: {self.axon}")
-
- bt.logging.info("Loading image detection model if configured")
- self.load_image_detector()
- bt.logging.info("Loading video detection model if configured")
- self.load_video_detector()
-
- def load_image_detector(self):
- if (str(self.config.neuron.image_detector).lower() == 'none' or
- str(self.config.neuron.image_detector_config).lower() == 'none'):
- bt.logging.warning("No image detector configuration provided, skipping.")
- self.image_detector = None
- return
+ self.initialization_complete = True
+
+ def shutdown(self):
+ if self.fast_api:
+ self.fast_api.stop()
- if self.config.neuron.image_detector_device == 'auto':
- bt.logging.warning("Automatic device configuration enabled for image detector")
- self.config.neuron.image_detector_device = get_device()
-
- self.image_detector = DETECTOR_REGISTRY[self.config.neuron.image_detector](
- config_name=self.config.neuron.image_detector_config,
- device=self.config.neuron.image_detector_device
+ async def log_on_block(self, block):
+ print_info(
+ self.metagraph,
+ self.wallet.hotkey.ss58_address,
+ block,
)
- bt.logging.info(f"Loaded image detection model: {self.config.neuron.image_detector}")
- def load_video_detector(self):
- if (str(self.config.neuron.video_detector).lower() == 'none' or
- str(self.config.neuron.video_detector_config).lower() == 'none'):
- bt.logging.warning("No video detector configuration provided, skipping.")
- self.video_detector = None
- return
+ async def detect_image(self, request: Request):
+ content_type = request.headers.get("Content-Type", "application/octet-stream")
+ image_data = await request.body()
+
+ signed_by = request.headers.get("Epistula-Signed-By", "")[:8]
+ bt.logging.info(
+ "\u2713",
+ f"Received image ({len(image_data)} bytes) from {signed_by}, type: {content_type}",
+ )
- if self.config.neuron.video_detector_device == 'auto':
- bt.logging.warning("Automatic device configuration enabled for video detector")
- self.config.neuron.video_detector_device = get_device()
+ if content_type not in ("image/jpeg", "application/octet-stream"):
+ bt.logging.warning(
+ f"Unexpected content type: {content_type}, expected image/jpeg"
+ )
- self.video_detector = DETECTOR_REGISTRY[self.config.neuron.video_detector](
- config_name=self.config.neuron.video_detector_config,
- device=self.config.neuron.video_detector_device
+ try:
+ image_array = np.array(Image.open(io.BytesIO(image_data)))
+ image_tensor = torch.from_numpy(image_array).permute(2, 0, 1)
+
+ ### PREDICT - update the Detector class with your own model and preprocessing
+ pred = self.detector.detect(image_tensor, "image")
+ return {"status": "success", "prediction": pred.tolist()}
+
+ except Exception as e:
+ bt.logging.error(f"Error processing image: {e}")
+ bt.logging.error(traceback.format_exc())
+ return {"status": "error", "message": str(e)}
+
+ async def detect_video(self, request: Request):
+ content_type = request.headers.get("Content-Type", "application/octet-stream")
+ video_data = await request.body()
+ signed_by = request.headers.get("Epistula-Signed-By", "")[:8]
+ bt.logging.info(
+ f"Received video ({len(video_data)} bytes) from {signed_by}, type: {content_type}",
)
- bt.logging.info(f"Loaded video detection model: {self.config.neuron.video_detector}")
+ if content_type not in ("video/mp4", "video/mpeg", "application/octet-stream"):
+ bt.logging.warning(
+ f"Unexpected content type: {content_type}, expected video/mp4 or video/mpeg"
+ )
+ try:
+ with tempfile.NamedTemporaryFile(suffix=".mp4") as temp_file:
+ temp_path = temp_file.name
+ temp_file.write(video_data)
+ temp_file.flush()
- async def forward_image(
- self, synapse: ImageSynapse
- ) -> ImageSynapse:
- """
- Perform inference on image
+ with av.open(temp_path) as container:
+ video_stream = next(
+ (s for s in container.streams if s.type == "video"), None
+ )
+ if not video_stream:
+ raise ValueError("No video stream found")
+ try:
+ codec_info = (
+ f"name: {video_stream.codec.name}"
+ if hasattr(video_stream, "codec")
+ else "unknown"
+ )
+ bt.logging.info(f"Video codec: {codec_info}")
+ except Exception as codec_err:
+ bt.logging.warning(
+ f"Could not get codec info: {str(codec_err)}"
+ )
+ duration = container.duration / 1000000 if container.duration else 0
+ width = video_stream.width
+ height = video_stream.height
+ fps = video_stream.average_rate
+ bt.logging.info(
+ f"Video dimensions: ({width}, {height}), fps: {fps}, duration: {duration:.2f}s"
+ )
+ frames = []
+ for frame in container.decode(video=0):
+ img_array = frame.to_ndarray(format="rgb24")
+ frames.append(img_array)
+ bt.logging.info(f"Extracted {len(frames)} frames")
+ if not frames:
+ raise ValueError("No frames could be extracted from the video")
+ video_array = np.stack(frames)
+ video_tensor = torch.permute(
+ torch.from_numpy(video_array), (3, 0, 1, 2) # (C, T, H, W)
+ )
- Args:
- synapse (bt.Synapse): The synapse object containing the list of b64 encoded images in the
- 'images' field.
+ ### PREDICT - update the Detector class with your own model and preprocessing
+ pred = self.detector.detect(video_tensor, "video")
+ return {"status": "success", "prediction": pred.tolist()}
+ except Exception as e:
+ bt.logging.error(f"Error processing video: {str(e)}")
+ bt.logging.error(traceback.format_exc())
+ return {"status": "error", "message": str(e)}
- Returns:
- bt.Synapse: The synapse object with the 'predictions' field populated with a list of probabilities
+ async def determine_epistula_version_and_verify(self, request: Request):
+ version = request.headers.get("Epistula-Version")
+ if version == EPISTULA_VERSION:
+ await self.verify_request(request)
+ return
+ raise HTTPException(status_code=400, detail="Unknown Epistula version")
- """
- if self.image_detector is None:
- bt.logging.info("Image detection model not configured; skipping image challenge")
- else:
- bt.logging.info("Received image challenge!")
- try:
- image_bytes = base64.b64decode(synapse.image)
- image = Image.open(io.BytesIO(image_bytes))
- synapse.prediction = self.image_detector(image)
- except Exception as e:
- bt.logging.error("Error performing inference")
- bt.logging.error(e)
-
- bt.logging.info(f"PREDICTION = {synapse.prediction}")
- label = synapse.testnet_label
- if synapse.testnet_label != -1:
- bt.logging.info(f"LABEL (testnet only) = {label}")
- return synapse
-
- async def forward_video(
- self, synapse: VideoSynapse
- ) -> VideoSynapse:
- """
- Perform inference on video
- Args:
- synapse (bt.Synapse): The synapse object containing the list of b64 encoded images in the
- 'images' field.
+ async def verify_request(
+ self,
+ request: Request,
+ ):
+ bt.logging.debug("Verifying request")
+ # We do this as early as possible so that now has a lesser chance
+ # of causing a stale request
+ now = round(time.time() * 1000)
- Returns:
- bt.Synapse: The synapse object with the 'predictions' field populated with a list of probabilities
+ # We need to check the signature of the body as bytes
+ # But use some specific fields from the body
+ signed_by = request.headers.get("Epistula-Signed-By")
+ signed_for = request.headers.get("Epistula-Signed-For")
+ if signed_for != self.wallet.hotkey.ss58_address:
+ raise HTTPException(
+ status_code=400, detail="Bad Request, message is not intended for self"
+ )
+ if signed_by not in self.metagraph.hotkeys:
+ raise HTTPException(status_code=401, detail="Signer not in metagraph")
- """
- if self.video_detector is None:
- bt.logging.info("Video detection model not configured; skipping video challenge")
- else:
- bt.logging.info("Received video challenge!")
+ uid = self.metagraph.hotkeys.index(signed_by)
+ stake = self.metagraph.S[uid].item()
+ if not self.config.no_force_validator_permit and stake < 10000:
+ bt.logging.warning(
+ f"Blacklisting request from {signed_by} [uid={uid}], not enough stake -- {stake}"
+ )
+            raise HTTPException(status_code=401, detail=f"Stake below minimum: {stake}")
+
+ # If anything is returned here, we can throw
+ body = await request.body()
+ err = verify_signature(
+ request.headers.get("Epistula-Request-Signature"),
+ body,
+ request.headers.get("Epistula-Timestamp"),
+ request.headers.get("Epistula-Uuid"),
+ signed_for,
+ signed_by,
+ now,
+ )
+ if err:
+ bt.logging.error(err)
+ raise HTTPException(status_code=400, detail=err)
+
+ def run(self):
+ assert self.config.netuid
+ assert self.config.subtensor
+ assert self.config.axon
+
+ # Serve passes the axon information to the network + netuid we are hosting on.
+        # This will auto-update if the axon port or external ip has changed.
+ external_ip = self.config.axon.external_ip or self.config.axon.ip
+ if not external_ip or external_ip == "[::]":
try:
- frames_tensor = decode_video_synapse(synapse)
- frames_tensor = frames_tensor.to(self.config.neuron.video_detector_device)
- synapse.prediction = self.video_detector(frames_tensor)
- except Exception as e:
- bt.logging.error("Error performing inference")
- bt.logging.error(e)
+ external_ip = requests.get("https://checkip.amazonaws.com").text.strip()
+ except Exception:
+ bt.logging.error("Failed to get external IP")
- bt.logging.info(f"PREDICTION = {synapse.prediction}")
- label = synapse.testnet_label
- if synapse.testnet_label != -1:
- bt.logging.info(f"LABEL (testnet only) = {label}")
- return synapse
+ bt.logging.info(f"Serving miner endpoint {external_ip}:{self.config.axon.port}")
+ bt.logging.info(
+            f"Network: {self.config.subtensor.chain_endpoint} | Netuid: {self.config.netuid}"
+ )
- async def blacklist_image(self, synapse: ImageSynapse) -> typing.Tuple[bool, str]:
- return await self.blacklist(synapse)
+ serve_success = serve_extrinsic(
+ subtensor=self.subtensor,
+ wallet=self.wallet,
+ ip=external_ip,
+ port=self.config.axon.port,
+ protocol=4,
+ netuid=self.config.netuid,
+ wait_for_finalization=True,
+ )
+ if not serve_success:
+ bt.logging.error("Failed to serve endpoint")
+ return
- async def blacklist_video(self, synapse: VideoSynapse) -> typing.Tuple[bool, str]:
- return await self.blacklist(synapse)
+ # Start starts the miner's endpoint, making it active on the network.
+ # change the config in the axon
+ app = FastAPI()
+ router = APIRouter()
+ router.add_api_route("/", ping, methods=["GET"])
- async def priority_image(self, synapse: ImageSynapse) -> float:
- return await self.priority(synapse)
+ router.add_api_route(
+ "/detect_image",
+ self.detect_image,
+ dependencies=[Depends(self.determine_epistula_version_and_verify)],
+ methods=["POST"],
+ )
+ router.add_api_route(
+ "/detect_video",
+ self.detect_video,
+ dependencies=[Depends(self.determine_epistula_version_and_verify)],
+ methods=["POST"],
+ )
+ app.include_router(router)
+ fast_config = uvicorn.Config(
+ app,
+ host="0.0.0.0",
+ port=self.config.axon.port,
+ log_level="info",
+ loop="asyncio",
+ )
+ self.fast_api = FastAPIThreadedServer(config=fast_config)
+ self.fast_api.start()
- async def priority_video(self, synapse: VideoSynapse) -> float:
- return await self.priority(synapse)
+ bt.logging.info(f"Miner {self.uid} starting at block: {self.subtensor.block}")
- def save_state(self):
- pass
+ # This loop maintains the miner's operations until intentionally stopped.
+ try:
+ while not self.exit_context.isExiting:
+ time.sleep(1)
+ # Make sure our substrate thread is alive
+ if not self.substrate_thread.is_alive():
+ bt.logging.info("Restarting substrate interface due to killed node")
+ self.substrate = SubstrateInterface(
+ ss58_format=SS58_FORMAT,
+ use_remote_preset=True,
+ url=self.config.subtensor.chain_endpoint,
+ type_registry=TYPE_REGISTRY,
+ )
+ self.substrate_thread = run_block_callback_thread(
+ self.substrate, self.run_callbacks
+ )
+ except Exception as e:
+ bt.logging.error(str(e))
+ bt.logging.error(traceback.format_exc())
+ finally:
+ self.shutdown()
-# This is the main function, which runs the miner.
-if __name__ == "__main__":
- import warnings
- warnings.filterwarnings("ignore")
- with Miner() as miner:
- while True:
- bt.logging.info(f"Miner running | uid {miner.uid} | {time.time()}")
- time.sleep(5)
+def ping():
+ return 200
+
+
+if __name__ == "__main__":
+ try:
+ miner = Miner()
+ miner.run()
+ except Exception as e:
+ bt.logging.error(str(e))
+ bt.logging.error(traceback.format_exc())
+ exit()
diff --git a/neurons/proxy.py b/neurons/proxy.py
new file mode 100644
index 00000000..3c53b701
--- /dev/null
+++ b/neurons/proxy.py
@@ -0,0 +1,647 @@
+import asyncio
+import base64
+import io
+import uuid
+import socket
+import tempfile
+import time
+import traceback
+from typing import Dict, List, Optional, Any, Union, Tuple
+
+import aiohttp
+import bittensor as bt
+import cv2
+import httpx
+import numpy as np
+import uvicorn
+from cryptography.exceptions import InvalidSignature
+from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
+from fastapi import (
+ FastAPI,
+ Request,
+ HTTPException,
+ Depends,
+ status,
+ APIRouter,
+)
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.security import APIKeyHeader
+from PIL import Image
+from bittensor.core.axon import FastAPIThreadedServer
+
+from bitmind.encoding import media_to_bytes
+from bitmind.epistula import query_miner
+from bitmind.metagraph import get_miner_uids
+from bitmind.transforms import get_base_transforms
+from bitmind.types import Modality, NeuronType
+from neurons.base import BaseNeuron
+
+AUTH_HEADER = APIKeyHeader(name="Authorization")
+DEFAULT_TIMEOUT = 9
+DEFAULT_SAMPLE_SIZE = 50
+
+
+class MediaProcessor:
+ def __init__(self, target_size: tuple = (256, 256)):
+ self.target_size = target_size
+
+    def process_image(self, b64_image: str) -> Tuple[bytes, str]:
+ """
+ Decode base64 image and preprocess
+
+ Args:
+ b64_image: Base64 encoded image string
+
+ Returns:
+            Tuple of (encoded image bytes, content type string)
+ """
+ try:
+ image_bytes = base64.b64decode(b64_image)
+ image = Image.open(io.BytesIO(image_bytes))
+ transformed_image = get_base_transforms(self.target_size)(np.array(image))
+ image_bytes, content_type = media_to_bytes(transformed_image)
+ return image_bytes, content_type
+
+ except Exception as e:
+ bt.logging.error(f"Error processing image: {e}")
+ raise ValueError(f"Failed to process image: {str(e)}")
+
+    def process_video(self, video_data: bytes) -> Tuple[bytes, str]:
+ """
+ Process raw video bytes into frames and preprocess
+
+ Args:
+ video_data: Raw video bytes
+
+ Returns:
+            Tuple of (encoded video bytes, content type string)
+ """
+        bt.logging.debug(f"Starting video processing with {len(video_data)} bytes")
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=True) as temp_file:
+ temp_file.write(video_data)
+ temp_file.flush()
+
+ cap = cv2.VideoCapture(temp_file.name)
+ if not cap.isOpened():
+ bt.logging.error("Failed to open video stream")
+ raise ValueError("Failed to open video stream")
+
+ try:
+ frames = []
+ while True:
+ success, frame = cap.read()
+ if not success:
+ break
+
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frames.append(rgb_frame)
+
+ bt.logging.info(f"Extracted {len(frames)} frames")
+
+ if not frames:
+ bt.logging.error("No frames extracted from video")
+ raise ValueError("No frames extracted from video")
+
+ transformed_frames = get_base_transforms(self.target_size)(
+ np.stack(frames)
+ )
+ video_bytes, content_type = media_to_bytes(transformed_frames)
+ return video_bytes, content_type
+
+ except Exception as e:
+ bt.logging.error(f"Error in video processing: {str(e)}")
+ raise
+ finally:
+ cap.release()
+
+
+class ValidatorProxy(BaseNeuron):
+ """
+ Proxy server that handles requests from external applications and forwards them to miners.
+ Uses FastAPIThreadedServer for improved concurrency.
+ """
+
+ neuron_type = NeuronType.VALIDATOR_PROXY
+
+ def __init__(self, config=None):
+ """
+ Initialize the proxy server.
+
+ Args:
+            config: Optional bittensor config for the proxy neuron
+ """
+ super().__init__(config=config)
+
+ if not (
+ hasattr(self.config, "proxy")
+ and hasattr(self.config.proxy, "client_url")
+ and hasattr(self.config.proxy, "port")
+ ):
+ raise ValueError(
+ "Missing proxy configuration - cannot initialize ValidatorProxy"
+ )
+
+ self.block_callbacks.append(self.log_on_block)
+
+ self.port = self.config.proxy.port
+ self.external_port = self.config.proxy.external_port
+ self.host = self.config.proxy.host
+ self.media_processor = MediaProcessor()
+ self.auth_verifier = self._setup_auth()
+
+ self.session = None
+ self.max_connections = 50
+ self.connector = None
+ self.fast_api = None
+
+ self.request_times = {
+ "image": [],
+ "video": [],
+ }
+ self.max_request_history = 100
+
+ self.setup_app()
+
+ bt.logging.info(f"Initialized proxy server on {self.host}:{self.port}")
+
+    def _setup_auth(self) -> Optional[Callable[[bytes], bool]]:
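+        """
+        Fetch signing credentials from the configured proxy client service and
+        build a verifier for incoming Authorization headers.
+
+        Returns:
+            A callable that verifies key bytes against the fetched signature,
+            or None if credential retrieval fails (auth checks are then skipped).
+        """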
+ try:
+ with httpx.Client() as client:
+ response = client.post(
+ f"{self.config.proxy.client_url}/get-credentials",
+ json={"postfix": f":{self.external_port}", "uid": self.uid},
+ timeout=DEFAULT_TIMEOUT,
+ )
+ creds = response.json()
+
+ signature = base64.b64decode(creds["signature"])
+ message = creds["message"]
+
+ def verify(key_bytes: bytes) -> bool:
+ try:
+ key = Ed25519PublicKey.from_public_bytes(key_bytes)
+ key.verify(signature, message.encode())
+ return True
+ except InvalidSignature:
+ return False
+
+ bt.logging.info("Authentication setup successful")
+ return verify
+
+ except Exception as e:
+ bt.logging.error(f"Error setting up authentication: {e}")
+ bt.logging.error("Authentication will be disabled")
+ return None
+
+ async def verify_auth(self, auth: str = Depends(AUTH_HEADER)) -> None:
+ if not self.auth_verifier:
+ return
+
+ try:
+ key_bytes = base64.b64decode(auth)
+ if not self.auth_verifier(key_bytes):
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid authentication token",
+ )
+        except HTTPException:
+            raise
+        except Exception as e:
+            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
+
+ async def log_on_block(self, block):
+ """
+ Log avg request times
+
+ Args:
+ block: Current block number
+ """
+ log_items = [f"Forward Block: {self.subtensor.block}"]
+ if self.request_times.get("image"):
+ avg_image_time = sum(self.request_times["image"]) / len(
+ self.request_times["image"]
+ )
+ log_items.append(f"Avg image request: {avg_image_time:.2f}s")
+
+ if self.request_times.get("video"):
+ avg_video_time = sum(self.request_times["video"]) / len(
+ self.request_times["video"]
+ )
+ log_items.append(f"Avg video request: {avg_video_time:.2f}s")
+
+ bt.logging.info(" | ".join(log_items))
+
+ def setup_app(self):
+ app = FastAPI(title="BitMind Proxy Server")
+ router = APIRouter()
+
+ app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ )
+
+ router.add_api_route(
+ "/forward_image",
+ self.handle_image_request,
+ methods=["POST"],
+ dependencies=[Depends(self.verify_auth)],
+ )
+ router.add_api_route(
+ "/forward_video",
+ self.handle_video_request,
+ methods=["POST"],
+ dependencies=[Depends(self.verify_auth)],
+ )
+ router.add_api_route(
+ "/healthcheck",
+ self.healthcheck,
+ methods=["GET"],
+ dependencies=[Depends(self.verify_auth)],
+ )
+
+ app.include_router(router)
+
+ fast_config = uvicorn.Config(
+ app,
+ host=self.host,
+ port=self.port,
+ log_level="info",
+ loop="asyncio",
+ workers=9,
+ )
+ self.fast_api = FastAPIThreadedServer(config=fast_config)
+
+ async def handle_image_request(self, request: Request) -> Dict[str, Any]:
+ """
+ Handle image processing requests.
+
+ Args:
+ request: FastAPI request object with JSON body containing base64 image
+
+ Returns:
+ Dictionary with prediction results
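+
+        Example request (illustrative; host, port, and key depend on deployment):
+            curl -X POST http://localhost:10913/forward_image
+                -H "Authorization: <base64 key>"
+                -H "Content-Type: application/json"
+                -d '{"image": "<base64-encoded image>", "rich": "true"}'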
+ """
+ start_time = time.time()
+ request_id = str(uuid.uuid4())[:8]
+ bt.logging.debug(f"[{request_id}] Starting image request processing")
+
+ try:
+ payload = await request.json()
+ if "image" not in payload:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Missing 'image' field in request body",
+ )
+ b64_image = payload["image"]
+
+ proc_start = time.time()
+ media_bytes, content_type = await asyncio.to_thread(
+ self.media_processor.process_image, b64_image
+ )
+ bt.logging.trace(
+ f"[{request_id}] Image processed in {time.time() - proc_start:.2f}s"
+ )
+
+ query_start = time.time()
+ results = await self.query_miners(
+ media_bytes=media_bytes,
+ content_type=content_type,
+ modality=Modality.IMAGE,
+ request_id=request_id,
+ )
+ bt.logging.trace(
+ f"[{request_id}] Miners queried in {time.time() - query_start:.2f}s"
+ )
+
+ predictions, uids = self.aggregate_responses(results)
+ response = {
+ "preds": [float(p) for p in predictions],
+ "fqdn": socket.getfqdn(),
+ }
+
+ # Add rich data if requested
+ if payload.get("rich", "").lower() == "true":
+ response.update(self._get_rich_data(uids))
+
+ total_time = time.time() - start_time
+ bt.logging.debug(
+ f"[{request_id}] Image request processed in {total_time:.2f}s"
+ )
+
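+            # Maintain a bounded rolling window of recent request latencies for block logs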
+ if len(self.request_times["image"]) >= self.max_request_history:
+ self.request_times["image"].pop(0)
+ self.request_times["image"].append(total_time)
+
+ return response
+
+ except Exception as e:
+ bt.logging.error(f"[{request_id}] Error processing image request: {e}")
+ bt.logging.error(traceback.format_exc())
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Error processing request: {str(e)}",
+ )
+
+ async def handle_video_request(self, request: Request) -> Dict[str, Any]:
+ """
+ Handle video processing requests.
+
+ Args:
+ request: FastAPI request object with form data containing video file
+
+ Returns:
+ Dictionary with prediction results
+ """
+ start_time = time.time()
+ request_id = str(uuid.uuid4())[:8]
+ bt.logging.debug(f"[{request_id}] Starting video request processing")
+
+ try:
+ form = await request.form()
+ if "video" not in form:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Missing 'video' field in form data",
+ )
+
+ video_file = form["video"]
+ video_data = await video_file.read()
+
+ if not video_data:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="Empty video file"
+ )
+
+ rich_param = form.get("rich", "").lower()
+
+ proc_start = time.time()
+ media_bytes, content_type = await asyncio.to_thread(
+ self.media_processor.process_video, video_data
+ )
+ bt.logging.trace(
+ f"[{request_id}] Video processed in {time.time() - proc_start:.2f}s"
+ )
+
+ query_start = time.time()
+ results = await self.query_miners(
+ media_bytes=media_bytes,
+ content_type=content_type,
+ modality=Modality.VIDEO,
+ request_id=request_id,
+ )
+ bt.logging.trace(
+ f"[{request_id}] Miners queried in {time.time() - query_start:.2f}s"
+ )
+
+ predictions, uids = self.aggregate_responses(results)
+ response = {
+ "preds": [float(p) for p in predictions],
+ "fqdn": socket.getfqdn(),
+ }
+
+ # Add rich data if requested
+ if rich_param == "true":
+ response.update(self._get_rich_data(uids))
+
+ total_time = time.time() - start_time
+ bt.logging.debug(
+ f"[{request_id}] Video request processed in {total_time:.2f}s"
+ )
+
+ if len(self.request_times["video"]) >= self.max_request_history:
+ self.request_times["video"].pop(0)
+ self.request_times["video"].append(total_time)
+ return response
+
+ except Exception as e:
+ bt.logging.error(f"Error processing video request: {e}")
+ bt.logging.error(traceback.format_exc())
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Error processing request: {str(e)}",
+ )
+
+ async def healthcheck(self, request: Request) -> Dict[str, str]:
+ """Health check endpoint."""
+ return {"status": "healthy"}
+
+ async def query_miners(
+ self,
+ media_bytes: bytes,
+ content_type: str,
+ modality: Modality,
+ num_miners: int = DEFAULT_SAMPLE_SIZE,
+ request_id: Optional[str] = None,
+ ) -> List[Dict[str, Any]]:
+ """
+ Query a set of miners with the given media.
+
+ Args:
+ media_bytes: Encoded media bytes
+ content_type: Media content type
+ modality: Media modality (image or video)
+ num_miners: Number of miners to sample
+
+ Returns:
+ List of miner responses
+ """
+ query_start = time.time()
+
+ miner_uids = get_miner_uids(
+ self.metagraph,
+ self.uid,
+ self.config.vpermit_tao_limit,
+ )
+ miner_uids = np.random.choice(
+ miner_uids, size=min(num_miners, len(miner_uids)), replace=False
+ )
+
+ total_timeout = self.config.neuron.miner_total_timeout
+ connect_timeout = self.config.neuron.miner_connect_timeout
+ sock_timeout = self.config.neuron.miner_sock_connect_timeout
+ read_timeout = self.config.neuron.miner_total_timeout
+
+ async with aiohttp.ClientSession(
+ timeout=aiohttp.ClientTimeout(
+ total=total_timeout,
+ connect=connect_timeout,
+ sock_connect=sock_timeout,
+ sock_read=read_timeout,
+ ),
+ connector=aiohttp.TCPConnector(limit=100),
+ ) as session:
+ challenge_tasks = []
+ for uid in miner_uids:
+ axon_info = self.metagraph.axons[uid]
+ challenge_tasks.append(
+ query_miner(
+ uid,
+ media_bytes,
+ content_type,
+ modality,
+ axon_info,
+ session,
+ self.wallet.hotkey,
+ total_timeout,
+ connect_timeout,
+ sock_timeout,
+ )
+ )
+
+            try:
+                # Gather with return_exceptions=True so failed queries stay aligned
+                # with miner_uids and can be converted to error placeholders below.
+                responses = await asyncio.wait_for(
+                    asyncio.gather(*challenge_tasks, return_exceptions=True),
+                    timeout=total_timeout,
+                )
+
+ filtered_responses = []
+ for i, response in enumerate(responses):
+ if isinstance(response, Exception):
+ bt.logging.warning(
+ f"Miner {miner_uids[i]} failed: {str(response)}"
+ )
+ filtered_responses.append(
+ {"uid": miner_uids[i], "error": True, "prediction": None}
+ )
+ else:
+ filtered_responses.append(response)
+
+ responses = filtered_responses
+
+ except asyncio.TimeoutError:
+ bt.logging.warning(
+ f"Timed out waiting for miner responses after {total_timeout}s"
+ )
+ responses = [
+ {"uid": uid, "error": True, "prediction": None}
+ for uid in miner_uids
+ ]
+
+ query_time = time.time() - query_start
+ bt.logging.info(
+ f"Received {len([r for r in responses if not r.get('error', False)])} valid miner responses for {modality} request in {query_time:.2f}s"
+ )
+ return responses
+
+ def _get_rich_data(self, uids: List[int]) -> Dict[str, List]:
+ """Get additional miner metadata."""
+ return {
+ "uids": [int(uid) for uid in uids],
+ "ranks": [float(self.metagraph.R[uid]) for uid in uids],
+ "incentives": [float(self.metagraph.I[uid]) for uid in uids],
+ "emissions": [float(self.metagraph.E[uid]) for uid in uids],
+ "hotkeys": [str(self.metagraph.hotkeys[uid]) for uid in uids],
+ "coldkeys": [str(self.metagraph.coldkeys[uid]) for uid in uids],
+ }
+
+ def aggregate_responses(
+ self, results: List[Dict[str, Any]]
+ ) -> Tuple[np.ndarray, List[int]]:
+ """
+ Aggregate miner responses into a final result.
+
+ Args:
+ results: List of miner responses
+
+ Returns:
+ Tuple of (aggregated predictions, responding miner UIDs)
+ """
+ valid_responses = [
+ r for r in results if not r["error"] and r["prediction"] is not None
+ ]
+
+ if not valid_responses:
+ bt.logging.warning("No valid responses received from miners")
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+ detail="No valid predictions received",
+ )
+
+ predictions = np.array([r["prediction"] for r in valid_responses])
+ uids = [r["uid"] for r in valid_responses]
+
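+        # Each prediction is a per-class output; indices 1 and 2 (assumed synthetic
+        # and semisynthetic) are summed into a single "fake" probability.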
+ predictions = [p[1] + p[2] for p in predictions]
+ return predictions, uids
+
+ async def start(self):
+ """Start the FastAPI threaded server and initialize connection pooling."""
+ bt.logging.info(f"Starting proxy server on {self.host}:{self.port}")
+
+ if self.connector is None:
+ self.connector = aiohttp.TCPConnector(
+ limit=self.max_connections,
+ limit_per_host=5,
+ enable_cleanup_closed=True,
+ force_close=False,
+ ttl_dns_cache=300,
+ )
+
+ if self.session is None:
+ self.session = aiohttp.ClientSession(
+ connector=self.connector,
+ timeout=aiohttp.ClientTimeout(
+ total=self.config.neuron.miner_total_timeout,
+ connect=self.config.neuron.miner_connect_timeout,
+ sock_connect=self.config.neuron.miner_sock_connect_timeout,
+ sock_read=self.config.neuron.miner_total_timeout,
+ ),
+ )
+
+ if self.fast_api:
+ self.fast_api.start()
+ else:
+ bt.logging.error("FastAPI server not initialized")
+
+ async def run(self):
+ await self.start()
+ while not self.exit_context.isExiting:
+ # Make sure our substrate thread is alive
+ if not self.substrate_thread.is_alive():
+ bt.logging.info("Restarting substrate interface due to killed node")
+ self.substrate = SubstrateInterface(
+ ss58_format=SS58_FORMAT,
+ use_remote_preset=True,
+ url=self.config.subtensor.chain_endpoint,
+ type_registry=TYPE_REGISTRY,
+ )
+ self.substrate_thread = run_block_callback_thread(
+ self.substrate, self.run_callbacks
+ )
+
+ await asyncio.sleep(1)
+
+ await self.shutdown()
+
+ async def shutdown(self):
+ """Shutdown the server and clean up resources."""
+ bt.logging.info("Shutting down proxy server")
+
+ if self.session:
+ await self.session.close()
+ self.session = None
+
+ if self.connector:
+ await self.connector.close()
+ self.connector = None
+
+ if self.fast_api:
+ self.fast_api.stop()
+ self.fast_api = None
+
+
+if __name__ == "__main__":
+ try:
+ proxy = ValidatorProxy()
+ asyncio.run(proxy.run())
+ except KeyboardInterrupt:
+ bt.logging.info("Proxy interrupted by KeyboardInterrupt, shutting down")
+ except Exception as e:
+ bt.logging.error(f"Unhandled exception: {e}")
+ bt.logging.error(traceback.format_exc())
diff --git a/neurons/unit_tests/sample_image.jpg b/neurons/unit_tests/sample_image.jpg
deleted file mode 100644
index fd2fff10..00000000
Binary files a/neurons/unit_tests/sample_image.jpg and /dev/null differ
diff --git a/neurons/unit_tests/test_miner.py b/neurons/unit_tests/test_miner.py
deleted file mode 100644
index cddd25b1..00000000
--- a/neurons/unit_tests/test_miner.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import unittest
-import torch
-import numpy as np
-from PIL import Image
-import os
-import sys
-import base64
-import io
-import asyncio
-
-# Miner class located in the parent directory
-directory = os.path.dirname(os.path.abspath(__file__))
-parent_directory = os.path.dirname(directory)
-sys.path.append(parent_directory)
-
-from miner import Miner
-from bitmind.base.miner import BaseMinerNeuron
-from bitmind.protocol import ImageSynapse
-
-class TestMiner(unittest.TestCase):
-
- def setUp(self):
- """Set up the necessary components for testing the Miner."""
-
- self.miner = Miner.__new__(Miner) # Create an instance of the Miner class without initialization
- self.miner.config = self.miner.config()
- self.script_dir = os.path.dirname(__file__)
- self.image_path = os.path.join(self.script_dir, 'sample_image.jpg')
-
- # Load a sample image and convert it to base64 for the synapse object
- with open(self.image_path, "rb") as img_file:
- self.image = Image.open(self.image_path)
- self.image_bytes = img_file.read()
- self.image_base64 = base64.b64encode(self.image_bytes).decode('utf-8')
-
- def test_init_detector(self):
- """Test if the models load properly with the given weight paths."""
- self.miner.load_detector()
- print(f"Detector: {self.miner.deepfake_detector}, Type:{type(self.miner.deepfake_detector)}")
- self.assertIsNotNone(self.miner.deepfake_detector, "Detector should not be None")
-
- def test_deepfake_detector(self):
- """Test the deepfake detection functionality."""
- # Test the deepfake detector directly
- self.miner.load_detector()
- prediction = self.miner.deepfake_detector(self.image)
- print(f"Prediction: {prediction}")
-
- # Check that the prediction is not None and within valid range (assuming it's a probability)
- self.assertIsNotNone(prediction, "Prediction should not be None")
- self.assertIsInstance(prediction, np.ndarray, "Prediction should be a numpy array")
- self.assertTrue(0 <= prediction <= 1, "Prediction should be between 0 and 1")
-
- def test_forward_synapse(self):
- """Test the forward method in the Miner class using a mock ImageSynapse."""
- # Create a mock synapse object with base64 encoded image
- self.miner.load_detector()
- synapse = ImageSynapse(image=self.image_base64)
-
- # Run the detector through the synapse and verify the output prediction
- pred = asyncio.run(self.miner.forward(synapse)).prediction
- print(f"Synapse prediction: {pred}")
-
- self.assertIsNotNone(pred, "Prediction in the synapse should not be None")
- self.assertIsInstance(pred, float, "Synapse prediction should be a f")
- self.assertTrue(0 <= pred <= 1, "Synapse prediction should be between 0 and 1")
-
-if __name__ == '__main__':
- unittest.main()
\ No newline at end of file
diff --git a/neurons/validator.py b/neurons/validator.py
index 2fa75e43..2d312771 100644
--- a/neurons/validator.py
+++ b/neurons/validator.py
@@ -1,161 +1,702 @@
-# The MIT License (MIT)
-# Copyright © 2023 Yuma Rao
-# developer: dubm
-# Copyright © 2023 Bitmind
+import asyncio
+import json
+import os
+import shutil
+import sys
+import threading
+import time
+import traceback
+from threading import Thread
+from time import sleep
+from typing import Any, Dict, Optional
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+import aiohttp
+import bittensor as bt
+import numpy as np
+from bittensor.core.settings import SS58_FORMAT, TYPE_REGISTRY
+from substrateinterface import SubstrateInterface
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
+from bitmind import __spec_version__ as spec_version
+from bitmind.autoupdater import autoupdate
+from bitmind.cache import CacheSystem
+from bitmind.config import MAINNET_UID
+from bitmind.encoding import media_to_bytes
+from bitmind.epistula import query_miner
+from bitmind.metagraph import (
+ create_set_weights,
+ get_miner_uids,
+ run_block_callback_thread,
+)
+from bitmind.scoring import EvalEngine
+from bitmind.transforms import apply_random_augmentations
+from bitmind.types import (
+ MediaType,
+ Modality,
+ NeuronType,
+)
+from bitmind.utils import on_block_interval, print_info
+from bitmind.wandb_utils import WandbLogger
+from neurons.base import BaseNeuron
-# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-import os
-os.environ["CUDA_VISIBLE_DEVICES"] = ""
+class Validator(BaseNeuron):
+ neuron_type = NeuronType.VALIDATOR
+ cache_system: Optional[CacheSystem] = None
+ heartbeat_thread: Thread
+ lock_waiting = False
+ lock_halt = False
+ step = 0
+ initialization_complete: bool = False
-import bittensor as bt
-import yaml
-import wandb
-import time
+ def __init__(self, config=None, run_init=True):
+ super().__init__(config=config)
-from neurons.validator_proxy import ValidatorProxy
-from bitmind.validator.forward import forward
-from bitmind.validator.cache import VideoCache, ImageCache
-from bitmind.base.validator import BaseValidatorNeuron
-from bitmind.validator.config import (
- MAINNET_UID,
- MAINNET_WANDB_PROJECT,
- TESTNET_WANDB_PROJECT,
- WANDB_ENTITY,
- REAL_VIDEO_CACHE_DIR,
- REAL_IMAGE_CACHE_DIR,
- SYNTH_VIDEO_CACHE_DIR,
- SYNTH_IMAGE_CACHE_DIR,
- SEMISYNTH_VIDEO_CACHE_DIR,
- SEMISYNTH_IMAGE_CACHE_DIR,
- VALIDATOR_INFO_PATH
-)
+ ## Typesafety
+ self.set_weights = create_set_weights(spec_version, self.config.netuid)
-import bitmind
-
-
-class Validator(BaseValidatorNeuron):
- """
- The BitMind Validator's `forward` function sends single-image challenges to miners every 30 seconds, where each
- image has a 50/50 chance of being real or fake. In service of this task, the Validator class has two key members -
- self.real_image_datasets and self.synthetic_image_generator. The former is a list of ImageDataset objects, which
- contain real images. The latter is an ML pipeline that combines an LLM for prompt generation and diffusion
- models that ingest prompts output by the LLM to produce synthetic images.
-
- The BitMind Validator also encapsuluates a ValidatorProxy, which is used to service organic requests from
- our consumer-facing application. If you wish to participate in this system, run your validator with the
- --proxy.port argument set to an exposed port on your machine.
- """
- def __init__(self, config=None):
- super(Validator, self).__init__(config=config)
- bt.logging.info("Starting validator with loaded scores:")
- bt.logging.info(self.scores)
-
- self.last_responding_miner_uids = []
- self.validator_proxy = ValidatorProxy(self)
-
- self.image_cache = {
- 'real': ImageCache(REAL_IMAGE_CACHE_DIR),
- 'synthetic': ImageCache(SYNTH_IMAGE_CACHE_DIR),
- 'semisynthetic': ImageCache(SEMISYNTH_IMAGE_CACHE_DIR),
- }
- self.video_cache = {
- 'real': VideoCache(REAL_VIDEO_CACHE_DIR),
- 'synthetic': VideoCache(SYNTH_VIDEO_CACHE_DIR),
- 'semisynthetic': VideoCache(SEMISYNTH_VIDEO_CACHE_DIR),
- }
- self.media_cache = {
- 'image': self.image_cache,
- 'video': self.video_cache
- }
+        ## CHECK VALIDATOR PERMIT
+ if (
+ not self.metagraph.validator_permit[self.uid]
+ and self.config.netuid == MAINNET_UID
+ ):
+ bt.logging.error("Validator does not have vpermit")
+ sys.exit(1)
+ if run_init:
+ self.init()
- self.init_wandb()
- self.store_vali_info()
+ def init(self):
+ assert self.config.netuid
+ assert self.config.vpermit_tao_limit
+ assert self.config.subtensor
- async def forward(self):
- """
- Validator forward pass. Consists of:
- - Generating the query
- - Querying the miners
- - Getting the responses
- - Rewarding the miners
- - Updating the scores
- """
- return await forward(self)
+ self._validate_challenge_probs()
+
+ if not self.config.wandb_off:
+ self.wandb_logger = WandbLogger(self.config, self.uid, self.wallet.hotkey)
+
+ bt.logging.info(self.config)
+ bt.logging.info(f"Last updated at block {self.metagraph.last_update[self.uid]}")
+
+ self.eval_engine = EvalEngine(self.metagraph, self.config)
+
+ ## REGISTER BLOCK CALLBACKS
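+        # Each callback is invoked with the new block number by the substrate
+        # block-subscription thread (see run_block_callback_thread).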
+ self.block_callbacks.extend(
+ [
+ self.log_on_block,
+ self.set_weights_on_interval,
+ self.send_challenge_to_miners_on_interval,
+ self.update_compressed_cache_on_interval,
+ self.update_media_cache_on_interval,
+                self.start_new_wandb_run_on_interval,
+ ]
+ )
+
+ # SETUP HEARTBEAT THREAD
+ if self.config.neuron.heartbeat:
+ self.heartbeat_thread = Thread(name="heartbeat", target=self.heartbeat)
+ self.heartbeat_thread.start()
+
+ ## DONE
+ bt.logging.info(
+ "\N{GRINNING FACE WITH SMILING EYES}", "Successfully Initialized!"
+ )
+
+ async def run(self):
+ assert self.config.subtensor
+ assert self.config.neuron
+ assert self.config.vpermit_tao_limit
+ bt.logging.info(
+ f"Running validator {self.uid} on network: {self.config.subtensor.chain_endpoint} with netuid: {self.config.netuid}"
+ )
+
+ await self.load_state()
+
+ self.cache_system = CacheSystem()
+ await self.cache_system.initialize(
+ self.config.cache.base_dir,
+ self.config.cache.max_compressed_gb,
+ self.config.cache.max_media_gb,
+ self.config.cache.media_files_per_source,
+ )
+
+ self.initialization_complete = True
+ bt.logging.success(
+ f"Initialization Complete. Validator starting at block: {self.subtensor.block}"
+ )
+
+ while not self.exit_context.isExiting:
+ self.step += 1
+ if self.config.autoupdate and (self.step == 0 or not self.step % 30):
+ bt.logging.debug("Checking autoupdate")
+ autoupdate(branch="v3")
+
+ # Make sure our substrate thread is alive
+ if not self.substrate_thread.is_alive():
+ bt.logging.info("Restarting substrate interface due to killed node")
+ self.substrate = SubstrateInterface(
+ ss58_format=SS58_FORMAT,
+ use_remote_preset=True,
+ url=self.config.subtensor.chain_endpoint,
+ type_registry=TYPE_REGISTRY,
+ )
+ self.substrate_thread = run_block_callback_thread(
+ self.substrate, self.run_callbacks
+ )
+
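+            # lock_halt is raised by set_weights_on_interval / save_state; the main
+            # loop acknowledges via lock_waiting and spins until the lock is released.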
+ if self.lock_halt:
+ self.lock_waiting = True
+ while self.lock_halt:
+ bt.logging.info("Waiting for lock to release")
+ sleep(self.config.neuron.lock_sleep_seconds)
+ self.lock_waiting = False
+ await asyncio.sleep(1)
+
+ await self.shutdown()
+
+ def heartbeat(self):
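+        """
+        Watchdog loop: if the main loop's step counter stops advancing for
+        max_stuck_count consecutive heartbeat intervals, force an update and
+        restart the process.
+        """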
+ bt.logging.info("Starting Heartbeat")
+ last_step = self.step
+ stuck_count = 0
+ while True:
+ while self.lock_halt:
+ sleep(self.config.neuron.lock_sleep_seconds)
+ sleep(self.config.neuron.heartbeat_interval_seconds)
+ if last_step == self.step:
+ stuck_count += 1
+ if last_step != self.step:
+ stuck_count = 0
+ if stuck_count >= self.config.neuron.max_stuck_count:
+ bt.logging.error(
+ "Heartbeat detecting main process hang, attempting restart"
+ )
+ autoupdate(force=True)
+ sys.exit(0)
+ last_step = self.step
+ bt.logging.info("Heartbeat")
+
+ @on_block_interval("challenge_interval")
+ async def send_challenge_to_miners_on_interval(self, block):
+ assert self.config.vpermit_tao_limit
- def init_wandb(self):
- if self.config.wandb.off:
+ miner_uids = get_miner_uids(
+ self.metagraph, self.uid, self.config.vpermit_tao_limit
+ )
+ if len(miner_uids) > self.config.neuron.sample_size:
+ miner_uids = np.random.choice(
+ miner_uids, size=self.config.neuron.sample_size, replace=False
+ ).tolist()
+
+ media_sample = await self._sample_media()
+ if not media_sample:
+ bt.logging.warning("Waiting for cache to populate. Challenge skipped.")
return
- run_name = f'validator-{self.uid}-{bitmind.__version__}'
- self.config.run_name = run_name
- self.config.uid = self.uid
- self.config.hotkey = self.wallet.hotkey.ss58_address
- self.config.version = bitmind.__version__
- self.config.type = self.neuron_type
+ modality = media_sample["modality"]
+ media = media_sample[modality]
- wandb_project = TESTNET_WANDB_PROJECT
- if self.config.netuid == MAINNET_UID:
- wandb_project = MAINNET_WANDB_PROJECT
+ media_bytes, content_type = media_to_bytes(
+ media, fps=media_sample.get("fps", None)
+ )
- # Initialize the wandb run for the single project
- bt.logging.info(f"Initializing W&B run for '{WANDB_ENTITY}/{wandb_project}'")
- try:
- run = wandb.init(
- name=run_name,
- project=wandb_project,
- entity=WANDB_ENTITY,
- config=self.config,
- dir=self.config.full_path,
- reinit=True
+ bt.logging.info(f"---------- Starting Challenge at Block {block} ----------")
+ bt.logging.info(f"Sampled from {modality} cache")
+
+ challenge_tasks = []
+ challenge_results = []
+ async with aiohttp.ClientSession() as session:
+ for uid in miner_uids:
+ axon_info = self.metagraph.axons[uid]
+ challenge_tasks.append(
+ query_miner(
+ uid,
+ media_bytes,
+ content_type,
+ modality,
+ axon_info,
+ session,
+ self.wallet.hotkey,
+ self.config.neuron.miner_total_timeout,
+ self.config.neuron.miner_connect_timeout,
+ self.config.neuron.miner_sock_connect_timeout,
+ )
+ )
+ if len(challenge_tasks) != 0:
+ responses = await asyncio.gather(*challenge_tasks)
+ challenge_results.extend(responses)
+ challenge_tasks = []
+
+ valid_responses = [r for r in challenge_results if not r["error"]]
+
+ n_valid = len(valid_responses)
+ n_failures = len(challenge_results) - len(valid_responses)
+ bt.logging.info(
+ f"Received {n_valid} valid miner responses. ({n_failures} others failed.)"
+ )
+
+ bt.logging.info(f"Scoring {modality} challenge")
+ rewards = {}
+ if n_valid > 0:
+ bt.logging.success(valid_responses)
+ rewards = self.eval_engine.score_challenge(
+ valid_responses,
+ media_sample["label"],
+ modality,
+ )
+
+ self.log_challenge_results(media_sample, challenge_results, rewards)
+
+ await self.save_state()
+ bt.logging.success(f"---------- Challenge Complete ----------")
+
+ @on_block_interval("compressed_cache_update_interval")
+ async def update_compressed_cache_on_interval(self, block):
+ if (
+ hasattr(self, "_compressed_cache_thread")
+ and self._compressed_cache_thread.is_alive()
+ ):
+ bt.logging.warning(
+ f"Previous compressed cache update still running at block {block}, skipping this update"
+ )
+ return
+
+ def update_compressed_cache():
+ """Thread function to update compressed cache."""
+ try:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ loop.run_until_complete(self.cache_system.update_compressed_caches())
+ bt.logging.info(f"Compressed cache update complete")
+ except Exception as e:
+ bt.logging.error(f"Error updating compressed caches: {e}")
+ bt.logging.error(traceback.format_exc())
+ finally:
+ loop.close()
+
+ bt.logging.info(f"Updating compressed caches at block {block}")
+ self._compressed_cache_thread = threading.Thread(
+ target=update_compressed_cache, daemon=True
+ )
+ self._compressed_cache_thread.start()
+
+ @on_block_interval("media_cache_update_interval")
+ async def update_media_cache_on_interval(self, block):
+ if hasattr(self, "_media_cache_thread") and self._media_cache_thread.is_alive():
+ bt.logging.warning(
+ f"Previous media cache update still running at block {block}, skipping this update"
)
- except wandb.UsageError as e:
- bt.logging.warning(e)
- bt.logging.warning("Did you run wandb login?")
return
- # Sign the run to ensure it's from the correct hotkey
- signature = self.wallet.hotkey.sign(run.id.encode()).hex()
- self.config.signature = signature
- wandb.config.update(self.config, allow_val_change=True)
+ def update_media_cache():
+ """Thread function to update media cache."""
+ try:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ loop.run_until_complete(self.cache_system.update_media_caches())
+ bt.logging.info(f"Media cache update complete")
+ except Exception as e:
+ bt.logging.error(f"Error updating media caches: {e}")
+ bt.logging.error(traceback.format_exc())
+ finally:
+ loop.close()
- bt.logging.success(f"Started wandb run {run_name}")
+ bt.logging.info(f"Updating media caches at block {block}")
+ self._media_cache_thread = threading.Thread(
+ target=update_media_cache, daemon=True
+ )
+ self._media_cache_thread.start()
- def store_vali_info(self):
+ @on_block_interval("epoch_length")
+ async def set_weights_on_interval(self, block):
+ try:
+ bt.logging.info(
+ f"Waiting to safely set weights at block {block} (epoch length = {self.config.epoch_length})"
+ )
+ self.lock_halt = True
+ while not self.lock_waiting and block != 0:
+ sleep(self.config.neuron.lock_sleep_seconds)
+
+ bt.logging.info(f"Setting weights at block {block}")
+ self.subtensor = bt.subtensor(
+ config=self.config, network=self.config.subtensor.chain_endpoint
+ )
+ weights = self.eval_engine.get_weights()
+ uids = list(range(self.metagraph.n))
+
+ self.set_weights(
+ self.wallet, self.metagraph, self.subtensor, (uids, weights)
+ )
+ bt.logging.success("Weights set successfully")
+
+ except Exception as e:
+ bt.logging.error(f"Error in set_weights_on_interval: {e}")
+ bt.logging.error(traceback.format_exc())
+ finally:
+ self.lock_halt = False
+
+ @on_block_interval("wandb_restart_interval")
+    async def start_new_wandb_run_on_interval(self, block):
+        if self.config.wandb_off:
+            return
+        try:
+            self.wandb_logger.start_new_run()
+ except Exception as e:
+ bt.logging.error(f"Not able to start new W&B run: {e}")
+
+ async def _sample_media(self) -> Optional[Dict[str, Any]]:
"""
- Stores the uid, hotkey and netuid of the currently running vali instance.
- The SyntheticDataGenerator process reads this to name its w&b run
+ Sample a media item from the cache system.
+
+ Returns:
+ Dictionary with media item details or None if sampling fails
"""
- validator_info = {
- 'uid': self.uid,
- 'hotkey': self.wallet.hotkey.ss58_address,
- 'netuid': self.config.netuid,
- 'full_path': self.config.neuron.full_path
+ if not self.cache_system:
+ return None
+
+ modality, media_type, multi_video = self.determine_challenge_type()
+
+ kwargs = {}
+ if modality == Modality.VIDEO:
+ kwargs = {
+ "min_duration": self.config.challenge.min_clip_duration,
+ "max_duration": self.config.challenge.max_clip_duration,
+ }
+
+ try:
+ sampler_name = f"{media_type}_{modality}_sampler"
+ results = await self.cache_system.sample(sampler_name, 1, **kwargs)
+ except Exception as e:
+ bt.logging.error(f"Error sampling media with {sampler_name}: {e}")
+ return None
+
+ if not results or results.get("count", 0) == 0:
+ return None
+
+ sample = results["items"][0]
+
+ if multi_video:
+ try:
+ # for now we stitch up to 2 videos together
+ max_duration = (
+ self.config.challenge.max_clip_duration
+ - sample["segment"]["duration"]
+ )
+ results = await self.cache_system.sample(
+ sampler_name, 1, max_duration=max_duration
+ )
+ except Exception as e:
+ bt.logging.error(f"Error sampling media with {sampler_name}: {e}")
+ return None
+
+ if results and results.get("count", 0) > 0:
+ sample = {"sample_0": sample, "sample_1": results["items"][0]}
+ sample["video"] = (
+ sample["sample_0"]["video"],
+ sample["sample_1"]["video"],
+ )
+ del sample["sample_0"]["video"]
+ del sample["sample_1"]["video"]
+
+ if sample and sample.get(modality) is not None:
+ bt.logging.debug("Augmenting Media")
+ augmented_media, _, _ = apply_random_augmentations(
+ sample.get(modality),
+ (256, 256),
+ sample.get("mask_center", None),
+ )
+ sample[modality] = augmented_media
+ sample.update(
+ {
+ "modality": modality,
+ "media_type": media_type,
+ "label": MediaType(media_type).int_value,
+ }
+ )
+ return sample
+
+ return None
+
+ def determine_challenge_type(self):
+ """
+ Randomly selects a modality (image, video) and media type (real, synthetic, semisynthetic)
+        based on configured probabilities.
+ """
+ modalities = [Modality.IMAGE.value, Modality.VIDEO.value]
+ modality = np.random.choice(
+ modalities,
+ p=[
+ self.config.challenge.image_prob,
+ self.config.challenge.video_prob,
+ ],
+ )
+
+ media_types = [
+ MediaType.REAL.value,
+ MediaType.SYNTHETIC.value,
+ MediaType.SEMISYNTHETIC.value,
+ ]
+ media_type = np.random.choice(
+ media_types,
+ p=[
+ self.config.challenge.real_prob,
+ self.config.challenge.synthetic_prob,
+ self.config.challenge.semisynthetic_prob,
+ ],
+ )
+
+ multi_video = (
+ modality == Modality.VIDEO
+ and np.random.rand() < self.config.challenge.multi_video_prob
+ )
+
+ return modality, media_type, multi_video
+
+ async def log_on_block(self, block):
+ """
+ Log information about validator state at regular intervals.
+
+ Args:
+ block: Current block number
+ """
+ blocks_till = self.config.epoch_length - (block % self.config.epoch_length)
+ bt.logging.info(
+ f"Forward Block: {self.subtensor.block} | Blocks till Set Weights: {blocks_till}"
+ )
+ print_info(
+ self.metagraph,
+ self.wallet.hotkey.ss58_address,
+ block,
+ )
+
+ if self.cache_system and block % 5 == 0:
+ try:
+ for name, sampler in self.cache_system.samplers.items():
+ count = sampler.get_available_count()
+ bt.logging.info(f"Cache status: {name} has {count} available items")
+
+ compressed_blocks = self.config.compressed_cache_update_interval - (
+ block % self.config.compressed_cache_update_interval
+ )
+ media_blocks = self.config.media_cache_update_interval - (
+ block % self.config.media_cache_update_interval
+ )
+ bt.logging.info(
+ f"Next compressed cache update in {compressed_blocks} blocks"
+ )
+ bt.logging.info(f"Next media cache update in {media_blocks} blocks")
+ except Exception as e:
+ bt.logging.error(f"Error logging cache status: {e}")
+
+ def log_challenge_results(self, media_sample, challenge_results, rewards):
+ uids = [d["uid"] for d in challenge_results]
+ results = {
+ "miner_uids": uids,
+ "miner_hotkeys": [d["hotkey"] for d in challenge_results],
+ "response_statuses": [d["status"] for d in challenge_results],
+ "response_errors": [d["error"] for d in challenge_results],
+ "predictions": [d["prediction"] for d in challenge_results],
+ "challenge_metadata": {
+ k: v for k, v in media_sample.items() if k != media_sample["modality"]
+ },
}
- with open(VALIDATOR_INFO_PATH, 'w') as f:
- yaml.safe_dump(validator_info, f, indent=4)
- bt.logging.info(f"Wrote validator info to {VALIDATOR_INFO_PATH}")
+ results["rewards"] = [rewards.get(uid, 0) for uid in uids]
+ results["scores"] = [self.eval_engine.scores[uid] for uid in uids]
+ results["metrics"] = [self.eval_engine.get_miner_metrics(uid) for uid in uids]
+ valid_indices = [
+ i for i in range(len(uids)) if not results["response_errors"][i]
+ ]
+ invalid_indices = [i for i in range(len(uids)) if results["response_errors"][i]]
-# The main function parses the configuration and runs the validator.
-if __name__ == "__main__":
- import warnings
- warnings.filterwarnings("ignore")
+ if self.config.netuid == MAINNET_UID:
+ for i in invalid_indices:
+ bt.logging.warning(
+ f"UID: {results['miner_uids'][i]} | "
+ f"HOTKEY: {results['miner_hotkeys'][i]} | "
+ f"STATUS: {results['response_statuses'][i]} | "
+ f"ERROR: {results['response_errors'][i]}"
+ )
- with Validator() as validator:
- while True:
- bt.logging.info(f"Validator running | uid {validator.uid} | {time.time()}")
- time.sleep(30)
+ for i in valid_indices:
+ bt.logging.success(
+ f"UID: {results['miner_uids'][i]} | "
+ f"HOTKEY: {results['miner_hotkeys'][i]} | "
+ f"PRED: {results['predictions'][i]}"
+ )
+ video_metrics = {
+ "video_" + k: f"{v:.4f}"
+ for k, v in results["metrics"][i]["video"].items()
+ }
+ video_metrics = [
+ f"{k.upper()}: {float(v)}" for k, v in video_metrics.items()
+ ]
+ image_metrics = {
+ "image_" + k: f"{v:.4f}"
+ for k, v in results["metrics"][i]["image"].items()
+ }
+ image_metrics = [
+ f"{k.upper()}: {float(v)}" for k, v in image_metrics.items()
+ ]
+ bt.logging.success(
+ f"{' | '.join(video_metrics)} | "
+ f"{' | '.join(image_metrics)} | "
+ f"REWARD: {results['rewards'][i]} | "
+ f"SCORE: {results['scores'][i]}"
+ )
+
+ bt.logging.info(json.dumps(results["challenge_metadata"], indent=2))
+
+ if not self.config.wandb_off:
+ self.wandb_logger.log(
+ media_sample=media_sample,
+ challenge_results=results,
+ )
+
+ def _validate_challenge_probs(self):
+ """
+ Validates that the challenge probabilities in config sum to 1.0.
+ """
+ total_modality = (
+ self.config.challenge.image_prob + self.config.challenge.video_prob
+ )
+ total_media = (
+ self.config.challenge.real_prob
+ + self.config.challenge.synthetic_prob
+ + self.config.challenge.semisynthetic_prob
+ )
+
+ if abs(total_modality - 1.0) > 1e-6:
+ raise ValueError(
+ f"Modality probabilities must sum to 1.0, got {total_modality} "
+ f"(image_prob={self.config.challenge.image_prob}, "
+ f"video_prob={self.config.challenge.video_prob})"
+ )
+
+ if abs(total_media - 1.0) > 1e-6:
+ raise ValueError(
+ f"Media type probabilities must sum to 1.0, got {total_media} "
+ f"(real_prob={self.config.challenge.real_prob}, "
+ f"synthetic_prob={self.config.challenge.synthetic_prob}, "
+ f"semisynthetic_prob={self.config.challenge.semisynthetic_prob})"
+ )
+
+ async def save_state(self):
+ """
+ Atomically save validator state (scores + miner history)
+ Maintains the current state and one backup.
+ """
+ self.lock_halt = True
+ while not self.lock_waiting:
+ sleep(self.config.neuron.lock_sleep_seconds)
+
+ try:
+ base_dir = self.config.neuron.full_path
+ os.makedirs(base_dir, exist_ok=True)
+
+ current_dir = os.path.join(base_dir, "state_current")
+ backup_dir = os.path.join(base_dir, "state_backup")
+ temp_dir = os.path.join(base_dir, "state_temp")
+
+ if os.path.exists(temp_dir):
+ shutil.rmtree(temp_dir)
+
+ os.makedirs(temp_dir)
+
+ # save to temp dir
+ self.eval_engine.save_state(temp_dir)
+ with open(os.path.join(temp_dir, "complete"), "w") as f:
+ f.write("1")
+
+ # backup current state
+ if os.path.exists(current_dir):
+ if os.path.exists(backup_dir):
+ shutil.rmtree(backup_dir)
+ os.rename(current_dir, backup_dir)
+
+ # move temp to current
+ os.rename(temp_dir, current_dir)
+
+ bt.logging.success("Saved validator state")
+
+ except Exception as e:
+ bt.logging.error(f"Error during state save: {str(e)}")
+ bt.logging.error(traceback.format_exc())
+ if os.path.exists(temp_dir):
+ shutil.rmtree(temp_dir)
+ finally:
+ self.lock_halt = False
+
+ async def load_state(self):
+ """
+ Load validator state, falling back to backup if needed.
+ """
+ base_dir = self.config.neuron.full_path
+ current_dir = os.path.join(base_dir, "state_current")
+ backup_dir = os.path.join(base_dir, "state_backup")
+
+ try:
+ if os.path.exists(current_dir) and os.path.exists(
+ os.path.join(current_dir, "complete")
+ ):
+ bt.logging.trace(
+ f"Attempting to load current validator state {current_dir}"
+ )
+ success = self.eval_engine.load_state(current_dir)
+ if success:
+ bt.logging.info("Successfully loaded current validator state")
+ return True
+ else:
+ bt.logging.warning("Failed to load current state, trying backup")
+ else:
+ bt.logging.warning(
+ "Current state not found or incomplete, trying backup"
+ )
+
+ # fall back to backup if needed
+ if os.path.exists(backup_dir) and os.path.exists(
+ os.path.join(backup_dir, "complete")
+ ):
+ current_time = time.time()
+ complete_marker = os.path.join(backup_dir, "complete")
+ marker_mod_time = os.path.getmtime(complete_marker)
+ backup_age_hours = (current_time - marker_mod_time) / 3600
+
+ max_age_hours = self.config.neuron.max_state_backup_hours
+ if backup_age_hours > max_age_hours:
+ bt.logging.warning(
+ f"Backup is {backup_age_hours:.2f} hours old (> {max_age_hours} hours), skipping load"
+ )
+ return False
+
+ bt.logging.trace(
+ f"Attempting to load backup validator state {backup_dir} (age: {backup_age_hours:.2f} hours)"
+ )
+ success = self.eval_engine.load_state(backup_dir)
+ if success:
+ bt.logging.info(
+ f"Successfully loaded backup validator state (age: {backup_age_hours:.2f} hours)"
+ )
+ return True
+ else:
+ bt.logging.error("Failed to load backup state")
+ return False
+ else:
+ bt.logging.warning("No valid state found")
+ return False
+ except Exception as e:
+ bt.logging.error(f"Error during state load: {str(e)}")
+ bt.logging.error(traceback.format_exc())
+ return False
+
+ async def shutdown(self):
+ """Shutdown the validator and clean up resources."""
+ bt.logging.info("Shutting down validator")
+
+
+if __name__ == "__main__":
+ try:
+ validator = Validator()
+ asyncio.run(validator.run())
+ except KeyboardInterrupt:
+ bt.logging.info("Validator interrupted by KeyboardInterrupt, shutting down")
+ except Exception as e:
+ bt.logging.error(f"Unhandled exception: {e}")
+ bt.logging.error(traceback.format_exc())
diff --git a/neurons/validator_proxy.py b/neurons/validator_proxy.py
deleted file mode 100644
index 73d50f90..00000000
--- a/neurons/validator_proxy.py
+++ /dev/null
@@ -1,318 +0,0 @@
-from fastapi import FastAPI, HTTPException, Depends, Request, status
-from fastapi.security import APIKeyHeader
-from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
-from cryptography.exceptions import InvalidSignature
-from concurrent.futures import ThreadPoolExecutor
-from PIL import Image
-from typing import Optional, Dict, List, Union, Any
-from dataclasses import dataclass
-from io import BytesIO
-from pathlib import Path
-import bittensor as bt
-import numpy as np
-import uvicorn
-import base64
-import tempfile
-import asyncio
-import cv2
-import os
-import httpx
-import time
-import socket
-from functools import lru_cache
-
-from bitmind.validator.config import TARGET_IMAGE_SIZE
-from bitmind.utils.image_transforms import get_base_transforms
-from bitmind.protocol import prepare_synapse
-from bitmind.utils.uids import get_random_uids
-from bitmind.validator.proxy import ProxyCounter
-
-# Constants
-AUTH_HEADER = APIKeyHeader(name="Authorization")
-FRAME_FORMAT = "RGB"
-DEFAULT_TIMEOUT = 9
-DEFAULT_SAMPLE_SIZE = 50
-
-
-class MediaProcessor:
- """Handles processing of images and videos"""
- def __init__(self, target_size: tuple):
- self.transforms = get_base_transforms(target_size)
-
- def process_image(self, b64_image: str) -> Any:
- """Process base64 encoded image"""
- image_bytes = base64.b64decode(b64_image)
- image = Image.open(BytesIO(image_bytes))
- return self.transforms(image)
-
- def process_video(self, video_data: bytes) -> List[Any]:
- """Process raw video bytes into transformed frames"""
- bt.logging.debug(f"Starting video processing with {len(video_data)} bytes")
-
- with tempfile.NamedTemporaryFile(suffix='.mp4', delete=True) as temp_file:
- bt.logging.debug(f"Created temp file: {temp_file.name}")
- temp_file.write(video_data)
- temp_file.flush()
-
- cap = cv2.VideoCapture(temp_file.name)
- if not cap.isOpened():
- bt.logging.error("Failed to open video stream")
- raise ValueError("Failed to open video stream")
- try:
- frames = []
- frame_count = 0
- while True:
- success, frame = cap.read()
- if not success:
- break
- frame_count += 1
- rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
- pil_frame = Image.fromarray(rgb_frame)
- frames.append(pil_frame)
-
- bt.logging.debug(f"Extracted {frame_count} frames")
-
- if not frames:
- bt.logging.error("No frames extracted from video")
- raise ValueError("No frames extracted from video")
-
- transformed = self.transforms(frames)
- bt.logging.debug(f"Transformed frames shape: {type(transformed)}")
- return transformed
-
- except Exception as e:
- bt.logging.error(f"Error in video processing: {str(e)}")
- raise
- finally:
- cap.release()
-
-
-class PredictionService:
- """Handles interaction with miners for predictions"""
- def __init__(self, validator, dendrite):
- self.validator = validator
- self.dendrite = dendrite
- self.metagraph = validator.metagraph
-
- async def get_predictions(
- self,
- data: Any,
- modality: str,
- timeout: int = DEFAULT_TIMEOUT
- ) -> tuple[List[float], List[int]]:
- """Get predictions from miners"""
-
- miner_uids = self._get_miner_uids()
- s = time.time()
- predictions = await self.dendrite(
- axons=[self.metagraph.axons[uid] for uid in miner_uids],
- synapse=prepare_synapse(data, modality=modality),
- deserialize=True,
- run_async=True,
- timeout=timeout
- )
- valid_indices = [i for i, v in enumerate(predictions) if -1 not in v]
- if not valid_indices:
- raise HTTPException(
- status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
- detail="No valid predictions received"
- )
- bt.logging.info(f"Got {len(valid_indices)} organic respones in {time.time()-s:.6f}s")
- valid_preds = np.array(predictions)[valid_indices]
- valid_uids = np.array(miner_uids)[valid_indices]
-
- return [p[1] + p[2] for p in valid_preds], valid_uids.tolist()
-
- def _get_miner_uids(self) -> List[int]:
- """Get list of miner UIDs to query"""
- uids = self.validator.last_responding_miner_uids
- if not uids:
- bt.logging.warning("No recent miner UIDs found, sampling random UIDs")
- uids = get_random_uids(self.validator, k=DEFAULT_SAMPLE_SIZE)
- return uids
-
- def get_rich_data(self, uids: List[int]) -> Dict[str, List]:
- """Get additional miner metadata"""
- return {
- 'uids': [int(uid) for uid in uids],
- 'ranks': [float(self.metagraph.R[uid]) for uid in uids],
- 'incentives': [float(self.metagraph.I[uid]) for uid in uids],
- 'emissions': [float(self.metagraph.E[uid]) for uid in uids],
- 'hotkeys': [str(self.metagraph.hotkeys[uid]) for uid in uids],
- 'coldkeys': [str(self.metagraph.coldkeys[uid]) for uid in uids]
- }
-
-class ValidatorProxy:
- """FastAPI server that proxies requests to validator miners"""
- def __init__(self, validator):
- self.validator = validator
- self.media_processor = MediaProcessor(TARGET_IMAGE_SIZE)
- self.dendrite = bt.dendrite(wallet=validator.wallet)
- self.prediction_service = PredictionService(validator, self.dendrite)
- self.metrics = ProxyCounter(os.path.join(validator.config.neuron.full_path, "proxy_counter.json"))
- self.app = FastAPI(title="Validator Proxy", version="1.0.0")
- self._configure_routes()
-
- if self.validator.config.proxy.port:
- self.auth_verifier = self._setup_auth()
- self.start()
-
- def _configure_routes(self):
- """Configure FastAPI routes"""
- self.app.add_api_route(
- "/forward_image",
- self.handle_image_request,
- methods=["POST"],
- dependencies=[Depends(self.verify_auth)]
- )
- self.app.add_api_route(
- "/forward_video",
- self.handle_video_request,
- methods=["POST"],
- dependencies=[Depends(self.verify_auth)]
- )
- self.app.add_api_route(
- "/healthcheck",
- self.healthcheck,
- methods=["GET"],
- dependencies=[Depends(self.verify_auth)]
- )
-
- def _setup_auth(self) -> callable:
- """Set up authentication verifier using synchronous HTTP client"""
- with httpx.Client() as client:
- response = client.post(
- f"{self.validator.config.proxy.proxy_client_url}/get-credentials",
- json={
- "postfix": f":{self.validator.config.proxy.port}" if self.validator.config.proxy.port else "",
- "uid": self.validator.uid
- },
- timeout=DEFAULT_TIMEOUT
- )
- creds = response.json()
-
- signature = base64.b64decode(creds["signature"])
- message = creds["message"]
-
- def verify(key_bytes: bytes) -> bool:
- try:
- key = Ed25519PublicKey.from_public_bytes(key_bytes)
- key.verify(signature, message.encode())
- return True
- except InvalidSignature:
- return False
-
- return verify
-
- async def verify_auth(self, auth: str = Depends(AUTH_HEADER)) -> None:
- """Verify authentication token"""
- try:
- key_bytes = base64.b64decode(auth)
- if not self.auth_verifier(key_bytes):
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail="Invalid authentication token"
- )
- except Exception as e:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail=str(e)
- )
-
- async def handle_image_request(self, request: Request) -> Dict[str, Any]:
- """Handle image processing requests"""
- payload = await request.json()
-
- try:
- image = self.media_processor.process_image(payload['image'])
- except Exception as e:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail=f"Failed to process image: {str(e)}"
- )
-
- predictions, uids = await self.prediction_service.get_predictions(
- image,
- modality='image'
- )
-
- response = {
- 'preds': predictions,
- 'fqdn': socket.getfqdn()
- }
-
- # add rich data if requested
- if payload.get('rich', '').lower() == 'true':
- response.update(self.prediction_service.get_rich_data(uids))
-
- self.metrics.update(is_success=True)
- return response
-
- async def handle_video_request(self, request: Request) -> Dict[str, Any]:
- """Handle video processing requests"""
- try:
- form = await request.form()
- if "video" not in form:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Missing video file in form data"
- )
-
- video_file = form["video"]
- bt.logging.debug(f"Received video file of type: {type(video_file)}")
-
- video_data = await video_file.read()
- bt.logging.debug(f"Read video data of size: {len(video_data)} bytes")
-
- if not video_data:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Empty video file"
- )
-
- s = time.time()
- try:
- video = self.media_processor.process_video(video_data)
- bt.logging.debug(f"Processed video into {len(video)} frames")
- except Exception as e:
- bt.logging.error(f"Video processing error: {str(e)}")
- bt.logging.error(f"Video data type: {type(video_data)}")
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail=f"Failed to process video: {str(e)}"
- )
-
- bt.logging.info(f"finished processing video in {time.time() - s:.6f}s")
- predictions, uids = await self.prediction_service.get_predictions(
- video,
- modality='video',
- )
- bt.logging.debug(f"Got predictions of length: {len(predictions)}")
-
- response = {
- 'preds': predictions,
- 'fqdn': socket.getfqdn()
- }
-
- # add rich data if requested
- rich_param = form.get('rich', '').lower()
- if rich_param == 'true':
- response.update(self.prediction_service.get_rich_data(uids))
-
- self.metrics.update(is_success=True)
- return response
-
- except Exception as e:
- bt.logging.error(f"Unexpected error in handle_video_request: {str(e)}")
- raise
-
- async def healthcheck(self, request: Request) -> Dict[str, str]:
- """Health check endpoint"""
- return {'status': 'healthy'}
-
- def start(self):
- """Start the FastAPI server"""
- self.executor = ThreadPoolExecutor(max_workers=1)
- self.executor.submit(
- uvicorn.run, self.app, host="0.0.0.0", port=self.validator.config.proxy.port
- )
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..9478f052
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,49 @@
+[build-system]
+requires = ["setuptools>=64", "wheel", "pip>=21.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "bitmind"
+dynamic = ["version"]
+description = "SN34 on bittensor"
+authors = [
+ {name = "BitMind", email = "intern@bitmind.ai"}
+]
+readme = "README.md"
+requires-python = ">=3.10"
+license = {text = "MIT"}
+urls = {homepage = "http://bitmind.ai"}
+
+dependencies = [
+ "bittensor==9.3.0",
+ "bittensor-cli==9.4.1",
+ "pillow==10.4.0",
+ "substrate-interface==1.7.11",
+ "numpy==2.0.1",
+ "pandas==2.2.3",
+ "torch==2.5.1",
+ "asyncpg==0.29.0",
+ "httpcore==1.0.7",
+ "httpx==0.28.1",
+ "pyarrow==19.0.1",
+ "ffmpeg-python==0.2.0",
+ "bitsandbytes==0.45.4",
+ "black==25.1.0",
+ "pre-commit==4.2.0",
+ "diffusers==0.33.1",
+ "transformers==4.50.0",
+ "scikit-learn==1.6.1",
+ "av==14.2.0",
+ "opencv-python==4.11.0.86",
+ "wandb==0.19.9",
+ "uvicorn==0.27.1",
+ "python-multipart==0.0.20",
+ "peft==0.15.0",
+ "hf_xet==1.1.1"
+]
+
+[tool.setuptools]
+packages = {find = {where = ["."], exclude = ["docs*", "wandb*", "*.egg-info"]}}
+
+[tool.setuptools.dynamic]
+version = {file = "VERSION"}
\ No newline at end of file
diff --git a/requirements-git.txt b/requirements-git.txt
new file mode 100644
index 00000000..c86921d2
--- /dev/null
+++ b/requirements-git.txt
@@ -0,0 +1 @@
+janus @ git+https://github.com/deepseek-ai/Janus.git
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 947e04ef..cc102792 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,43 +1,24 @@
-# Core ML frameworks
-bittensor==9.0.3
-torch==2.5.1
-torchvision==0.20.1
-torchaudio==2.5.1
-tensorflow==2.18.0
-tf-keras==2.18.0
-scikit-learn==1.5.2
-
-# Deep learning tools
-transformers==4.48.0
-diffusers==0.33.1
-accelerate==1.2.0
-bitsandbytes==0.45.0
-sentencepiece==0.2.0
-timm==1.0.12
-einops==0.8.0
-ultralytics==8.3.44
-janus @ git+https://github.com/deepseek-ai/Janus.git
-peft==0.15.0
-
-# Image/Video processing
-datasets==3.1.0
-opencv-python==4.10.0.84
+bittensor==9.3.0
+bittensor-cli==9.4.1
pillow==10.4.0
-imageio==2.36.1
-imageio-ffmpeg==0.5.1
-moviepy==2.1.1
-av==13.1.0
-dlib==19.24.6
+substrate-interface==1.7.11
+numpy==2.0.1
+pandas==2.2.3
+torch==2.5.1
+asyncpg==0.29.0
+httpcore==1.0.7
+httpx==0.28.1
+pyarrow==19.0.1
ffmpeg-python==0.2.0
-pyffmpeg==2.4.2.18.1
-imutils==0.5.4
-scikit-image==0.24.0
-
-# Data and logging
-wandb==0.19.0
-tensorboardx==2.6.2.2
-loguru==0.7.2
-httpx==0.27.0
-joblib==1.4.2
+bitsandbytes==0.45.4
+black==25.1.0
+pre-commit==4.2.0
+diffusers==0.33.1
+transformers==4.50.0
+scikit-learn==1.6.1
+av==14.2.0
+opencv-python==4.11.0.86
+wandb==0.19.9
+uvicorn==0.27.1
python-multipart==0.0.20
-strenum==0.4.15
+peft==0.15.0
diff --git a/run_neuron.py b/run_neuron.py
deleted file mode 100644
index 7904b256..00000000
--- a/run_neuron.py
+++ /dev/null
@@ -1,82 +0,0 @@
-"""
-Thank you to Namoray of SN19 for their autoupdate implementation!
-"""
-import os
-import sys
-import subprocess
-import time
-import argparse
-
-# self heal restart interval
-RESTART_INTERVAL_HOURS = 3
-
-
-def should_update_local(local_commit, remote_commit):
- return local_commit != remote_commit
-
-
-def run_auto_update_self_heal(neuron_type, auto_update, self_heal):
- last_restart_time = time.time()
-
- while True:
- time.sleep(60)
-
- if auto_update:
- current_branch = subprocess.getoutput("git rev-parse --abbrev-ref HEAD")
- local_commit = subprocess.getoutput("git rev-parse HEAD")
- os.system("git fetch")
- remote_commit = subprocess.getoutput(f"git rev-parse origin/{current_branch}")
-
- if should_update_local(local_commit, remote_commit):
- print("Local repo is not up-to-date. Updating...")
- reset_cmd = "git reset --hard " + remote_commit
- process = subprocess.Popen(reset_cmd.split(), stdout=subprocess.PIPE)
- output, error = process.communicate()
-
- if error:
- print("Error in updating:", error)
- else:
- print("Updated local repo to latest version: {}", format(remote_commit))
-
- print("Running the autoupdate steps...")
- # Trigger shell script. Make sure this file path starts from root
- os.system(f"./autoupdate_{neuron_type}_steps.sh")
- time.sleep(20)
- print("Finished running the autoupdate steps 😎")
- print("Restarting neuron")
- os.system(f"./start_{neuron_type}.sh")
- else:
- print("Repo is up-to-date.")
-
- if self_heal:
- # Check if it's time to restart the PM2 process
- if time.time() - last_restart_time >= RESTART_INTERVAL_HOURS * 3600:
- os.system(f"./start_{neuron_type}.sh")
- last_restart_time = time.time() # Reset the timer after the restart
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="Bittensor neuron run script with optional self-healing and auto-update.")
- parser.add_argument("--validator", action="store_true")
- parser.add_argument("--miner", action="store_true")
- parser.add_argument("--no-self-heal", action="store_true", help="Disable the automatic restart of the PM2 process")
- parser.add_argument("--no-auto-update", action="store_true", help="Disable the automatic update of the local repository")
- parser.add_argument("--clear-cache", action="store_true", help="Clear the cache before starting validator")
-
- args = parser.parse_args()
- if not (args.miner ^ args.validator):
- print(f"Usage: python {__file__}" + "--validator | --miner [--no-self-heal --no-auto-update]")
- sys.exit(1)
-
- neuron_type = 'miner' if args.miner else 'validator'
-
- if args.clear_cache and args.validator:
- os.system(f"./start_{neuron_type}.sh --clear-cache")
- else:
- os.system(f"./start_{neuron_type}.sh")
-
- if not args.no_auto_update or not args.no_self_heal:
- run_auto_update_self_heal(
- neuron_type,
- auto_update=not args.no_auto_update,
- self_heal=not args.no_self_heal)
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 54812047..00000000
--- a/setup.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# The MIT License (MIT)
-# Copyright © 2023 Yuma Rao
-# TODO(developer): Set your name
-# Copyright © 2023
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
-
-# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-
-import re
-import os
-import codecs
-import pathlib
-from os import path
-from io import open
-from setuptools import setup, find_packages
-from pkg_resources import parse_requirements
-
-
-def read_requirements(path):
- with open(path, "r") as f:
- requirements = f.read().splitlines()
- processed_requirements = []
-
- for req in requirements:
- # For git or other VCS links
- if req.startswith("git+") or "@" in req:
- pkg_name = re.search(r"(#egg=)([\w\-_]+)", req)
- if pkg_name:
- processed_requirements.append(pkg_name.group(2))
- else:
- # You may decide to raise an exception here,
- # if you want to ensure every VCS link has an #egg= at the end
- continue
- else:
- processed_requirements.append(req)
- return processed_requirements
-
-
-requirements = read_requirements("requirements.txt")
-here = path.abspath(path.dirname(__file__))
-
-with open(path.join(here, "README.md"), encoding="utf-8") as f:
- long_description = f.read()
-
-# loading version from setup.py
-with codecs.open(
- os.path.join(here, "bitmind/__init__.py"), encoding="utf-8"
-) as init_file:
- version_match = re.search(
- r"^__version__ = ['\"]([^'\"]*)['\"]", init_file.read(), re.M
- )
- version_string = version_match.group(1)
-
-setup(
- name="bitmind_subnet",
- version=version_string,
- description="bitmind_subnet",
- long_description=long_description,
- long_description_content_type="text/markdown",
- url="https://github.com/BitMind-AI/bitmind-subnet",
- author="bittensor.com", # TODO(developer): Change this value to your module subnet author name.
- packages=find_packages(),
- include_package_data=True,
- author_email="", # TODO(developer): Change this value to your module subnet author email.
- license="MIT",
- python_requires=">=3.8",
- install_requires=requirements,
- classifiers=[
- "Development Status :: 3 - Alpha",
- "Intended Audience :: Developers",
- "Topic :: Software Development :: Build Tools",
- # Pick your license as you wish
- "License :: OSI Approved :: MIT License",
- "Programming Language :: Python :: 3 :: Only",
- "Programming Language :: Python :: 3.8",
- "Programming Language :: Python :: 3.9",
- "Programming Language :: Python :: 3.10",
- "Topic :: Scientific/Engineering",
- "Topic :: Scientific/Engineering :: Mathematics",
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
- "Topic :: Software Development",
- "Topic :: Software Development :: Libraries",
- "Topic :: Software Development :: Libraries :: Python Modules",
- ],
-)
diff --git a/setup.sh b/setup.sh
new file mode 100755
index 00000000..d14a0b54
--- /dev/null
+++ b/setup.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+
+###########################################
+# System Updates and Package Installation #
+###########################################
+
+# Update system
+sudo apt update -y
+
+# Remove old nodejs and npm if present
+sudo apt-get remove --purge -y nodejs npm
+
+# Install Node.js 20.x (LTS) from NodeSource for stability and broad compatibility
+# NOTE: Update the version here when a new LTS is released
+curl -fsSL https://deb.nodesource.com/setup_20.x | sudo -E bash -
+sudo apt-get install -y nodejs
+
+# Install core dependencies
+sudo apt install -y \
+ python3-pip \
+ nano \
+ libgl1 \
+ ffmpeg \
+ unzip
+
+# Install build dependencies
+sudo apt install -y \
+ build-essential \
+ cmake \
+ libopenblas-dev \
+ liblapack-dev \
+ libx11-dev \
+ libgtk-3-dev
+
+# Install process manager (pm2) globally
+sudo npm install -g pm2@latest
+
+############################
+# Python Package Installation
+############################
+
+pip install --use-pep517 -e . -r requirements-git.txt
+
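
A minimal usage sketch for the new setup script, assuming a Debian/Ubuntu host with sudo access and that it is run from the repository root (the script itself performs the apt, NodeSource, pm2, and pip steps shown above):

    chmod +x setup.sh
    ./setup.sh
    pm2 --version    # verify the process manager is available
    node --version   # should report the Node.js 20.x LTS installed above
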
diff --git a/setup_env.sh b/setup_env.sh
deleted file mode 100755
index 8b1b2470..00000000
--- a/setup_env.sh
+++ /dev/null
@@ -1,111 +0,0 @@
-#!/bin/bash
-
-###########################################
-# System Updates and Package Installation #
-###########################################
-
-# Update system
-sudo apt update -y
-
-# Install core dependencies
-sudo apt install -y \
- python3-pip \
- nano \
- libgl1 \
- npm \
- ffmpeg \
- unzip
-
-# Install build dependencies
-sudo apt install -y \
- build-essential \
- cmake \
- libopenblas-dev \
- liblapack-dev \
- libx11-dev \
- libgtk-3-dev
-
-# Install process manager
-sudo npm install -g pm2@latest
-
-############################
-# Python Package Installation
-############################
-
-pip install -e .
-pip install -r requirements.txt
-
-############################
-# Environment Files Setup #
-############################
-
-# Create miner.env if it doesn't exist
-if [ -f "miner.env" ]; then
- echo "File 'miner.env' already exists. Skipping creation."
-else
- cat > miner.env << 'EOL'
-# Default options
-#--------------------
-
-# Detector Configuration
-IMAGE_DETECTOR=CAMO # Options: CAMO, UCF, NPR, None
-IMAGE_DETECTOR_CONFIG=camo.yaml # Configs in base_miner/deepfake_detectors/configs
-VIDEO_DETECTOR=TALL # Options: TALL, None
-VIDEO_DETECTOR_CONFIG=tall.yaml # Configs in base_miner/deepfake_detectors/configs
-
-# Device Settings
-IMAGE_DETECTOR_DEVICE=cpu # Options: cpu, cuda
-VIDEO_DETECTOR_DEVICE=cpu
-
-# Subtensor Network Configuration
-NETUID=34 # Network User ID options: 34, 168
-SUBTENSOR_NETWORK=finney # Networks: finney, test, local
-SUBTENSOR_CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443
- # Endpoints:
- # - wss://entrypoint-finney.opentensor.ai:443
- # - wss://test.finney.opentensor.ai:443/
-
-# Wallet Configuration
-WALLET_NAME=default
-WALLET_HOTKEY=default
-
-# Miner Settings
-MINER_AXON_PORT=8091
-BLACKLIST_FORCE_VALIDATOR_PERMIT=True # Force validator permit for blacklisting
-EOL
- echo "File 'miner.env' created."
-fi
-
-# Create validator.env if it doesn't exist
-if [ -f "validator.env" ]; then
- echo "File 'validator.env' already exists. Skipping creation."
-else
- cat > validator.env << 'EOL'
-# Default options
-#--------------------
-
-# Subtensor Network Configuration
-NETUID=34 # Network User ID options: 34, 168
-SUBTENSOR_NETWORK=finney # Networks: finney, test, local
-SUBTENSOR_CHAIN_ENDPOINT=wss://entrypoint-finney.opentensor.ai:443
- # Endpoints:
- # - wss://entrypoint-finney.opentensor.ai:443
- # - wss://test.finney.opentensor.ai:443/
-
-# Wallet Configuration
-WALLET_NAME=default
-WALLET_HOTKEY=default
-
-# Validator Settings
-VALIDATOR_AXON_PORT=8092 # If using RunPod, must be >= 70000 for symmetric mapping
-VALIDATOR_PROXY_PORT=10913
-DEVICE=cuda
-
-# API Keys
-WANDB_API_KEY=your_wandb_api_key_here
-HUGGING_FACE_TOKEN=your_hugging_face_token_here
-EOL
- echo "File 'validator.env' created."
-fi
-
-echo "Environment setup completed successfully."
diff --git a/start_miner.sh b/start_miner.sh
index dd3c94f7..a58743f6 100755
--- a/start_miner.sh
+++ b/start_miner.sh
@@ -1,25 +1,70 @@
#!/bin/bash
+###################################
+# LOAD ENV FILE
+###################################
set -a
-source miner.env
+source .env.miner
set +a
-if pm2 list | grep -q "bitmind_miner"; then
- echo "Process 'bitmind_miner' is already running. Deleting it..."
- pm2 delete bitmind_miner
+###################################
+# PREPARE CLI ARGS
+###################################
+if [[ "$CHAIN_ENDPOINT" == *"test"* ]]; then
+ NETUID=168
+ NETWORK="test"
+elif [[ "$CHAIN_ENDPOINT" == *"finney"* ]]; then
+ NETUID=34
+ NETWORK="finney"
fi
-pm2 start neurons/miner.py --name bitmind_miner -- \
- --neuron.image_detector ${IMAGE_DETECTOR:-None} \
- --neuron.image_detector_config ${IMAGE_DETECTOR_CONFIG:-None} \
- --neuron.image_detector_device ${IMAGE_DETECTOR_DEVICE:-None} \
- --neuron.video_detector ${VIDEO_DETECTOR:-None} \
- --neuron.video_detector_config ${VIDEO_DETECTOR_CONFIG:-None} \
- --neuron.video_detector_device ${VIDEO_DETECTOR_DEVICE:-None} \
- --netuid $NETUID \
- --subtensor.network $SUBTENSOR_NETWORK \
- --subtensor.chain_endpoint $SUBTENSOR_CHAIN_ENDPOINT \
+case "$LOGLEVEL" in
+ "trace")
+ LOG_PARAM="--logging.trace"
+ ;;
+ "debug")
+ LOG_PARAM="--logging.debug"
+ ;;
+ "info")
+ LOG_PARAM="--logging.info"
+ ;;
+ *)
+ # Default to info if LOGLEVEL is not set or invalid
+ LOG_PARAM="--logging.info"
+ ;;
+esac
+
+# Set force-validator-permit parameter based on FORCE_VPERMIT
+FORCE_VPERMIT_PARAM=""
+if [ "$FORCE_VPERMIT" = false ]; then
+ FORCE_VPERMIT_PARAM="--no-force-validator-permit"
+fi
+
+
+###################################
+# RESTART PROCESSES
+###################################
+NAME="bitmind-miner"
+
+# Stop any existing processes
+if pm2 list | grep -q "$NAME"; then
+ echo "'$NAME' is already running. Deleting it..."
+ pm2 delete $NAME
+fi
+
+echo "Starting $NAME | chain_endpoint: $CHAIN_ENDPOINT | netuid: $NETUID"
+
+# Run miner
+pm2 start neurons/miner.py \
+ --interpreter python3 \
+ --name $NAME \
+ -- \
--wallet.name $WALLET_NAME \
--wallet.hotkey $WALLET_HOTKEY \
- --axon.port $MINER_AXON_PORT \
- --blacklist.force_validator_permit $BLACKLIST_FORCE_VALIDATOR_PERMIT
+ --netuid $NETUID \
+ --subtensor.chain_endpoint $CHAIN_ENDPOINT \
+ --axon.port $AXON_PORT \
+ --axon.external_ip $AXON_EXTERNAL_IP \
+ --device $DEVICE \
+ $FORCE_VPERMIT_PARAM
+
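
The miner start script is now driven entirely by .env.miner. A hedged usage sketch, assuming the .env.miner template shipped with the repo has been copied and its wallet fields filled in (the bitmind-miner pm2 process name comes from this diff):

    cp .env.miner.template .env.miner   # then edit WALLET_NAME, WALLET_HOTKEY, etc.
    ./start_miner.sh
    pm2 logs bitmind-miner              # follow the miner logs
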
diff --git a/start_validator.sh b/start_validator.sh
index c5be7244..e7345b1d 100755
--- a/start_validator.sh
+++ b/start_validator.sh
@@ -1,84 +1,147 @@
#!/bin/bash
-# Load environment variables from .env file & set defaults
+###################################
+# LOAD ENV FILE
+###################################
set -a
-source validator.env
+source .env.validator
set +a
-: ${VALIDATOR_PROXY_PORT:=10913}
-: ${DEVICE:=cuda}
-
-VALIDATOR_PROCESS_NAME="bitmind_validator"
-DATA_GEN_PROCESS_NAME="bitmind_data_generator"
-CACHE_UPDATE_PROCESS_NAME="bitmind_cache_updater"
-
-# Clear cache if specified
-while [[ $# -gt 0 ]]; do
- case $1 in
- --clear-cache)
- rm -rf ~/.cache/sn34
- shift
- ;;
- *)
- shift
- ;;
- esac
-done
-
+###################################
+# LOG IN TO THIRD PARTY SERVICES
+###################################
# Login to Weights & Biases
if ! wandb login $WANDB_API_KEY; then
echo "Failed to login to Weights & Biases with the provided API key."
exit 1
fi
+echo "Logged into W&B with API key provided in .env.validator"
# Login to Hugging Face
if ! huggingface-cli login --token $HUGGING_FACE_TOKEN; then
echo "Failed to login to Hugging Face with the provided token."
exit 1
fi
+echo "Logged into W&B with token provided in .env.validator"
-# STOP VALIDATOR PROCESS
-if pm2 list | grep -q "$VALIDATOR_PROCESS_NAME"; then
- echo "Process '$VALIDATOR_PROCESS_NAME' is already running. Deleting it..."
- pm2 delete $VALIDATOR_PROCESS_NAME
+###################################
+# PREPARE CLI ARGS
+###################################
+: ${PROXY_PORT:=10913}
+: ${PROXY_EXTERNAL_PORT:=$PROXY_PORT}
+: ${DEVICE:=cuda}
+
+if [[ "$CHAIN_ENDPOINT" == *"test"* ]]; then
+ NETUID=168
+ NETWORK="test"
+elif [[ "$CHAIN_ENDPOINT" == *"finney"* ]]; then
+ NETUID=34
+ NETWORK="finney"
fi
-# STOP REAL DATA CACHE UPDATER PROCESS
-if pm2 list | grep -q "$CACHE_UPDATE_PROCESS_NAME"; then
- echo "Process '$CACHE_UPDATE_PROCESS_NAME' is already running. Deleting it..."
- pm2 delete $CACHE_UPDATE_PROCESS_NAME
+case "$LOGLEVEL" in
+ "trace")
+ LOG_PARAM="--logging.trace"
+ ;;
+ "debug")
+ LOG_PARAM="--logging.debug"
+ ;;
+ "info")
+ LOG_PARAM="--logging.info"
+ ;;
+ *)
+ # Default to info if LOGLEVEL is not set or invalid
+ LOG_PARAM="--logging.info"
+ ;;
+esac
+
+# Set auto-update parameter based on AUTO_UPDATE
+if [ "$AUTO_UPDATE" = true ]; then
+ AUTO_UPDATE_PARAM=""
+else
+ AUTO_UPDATE_PARAM="--autoupdate-off"
fi
-# STOP SYNTHETIC DATA GENERATOR PROCESS
-if pm2 list | grep -q "$DATA_GEN_PROCESS_NAME"; then
- echo "Process '$DATA_GEN_PROCESS_NAME' is already running. Deleting it..."
- pm2 delete $DATA_GEN_PROCESS_NAME
+if [ "$HEARTBEAT" = true ]; then
+ HEARTBEAT_PARAM="--heartbeat"
+else
+ HEARTBEAT_PARAM=""
fi
+###################################
+# STOP AND WAIT FOR CLEANUP
+###################################
+VALIDATOR="sn34-validator"
+GENERATOR="sn34-generator"
+PROXY="sn34-proxy"
-WANDB_DIR="$(dirname "$(realpath "${BASH_SOURCE[0]}")")/wandb"
-echo "Pruning $WANDB_DIR"
-python3 bitmind/validator/scripts/prune_wandb_cache.py --dir $WANDB_DIR
+# Stop any existing processes
+if pm2 list | grep -q "$VALIDATOR"; then
+ echo "'$VALIDATOR' is already running. Deleting it..."
+ pm2 delete $VALIDATOR
+ sleep 1
+fi
-echo "Verifying access to synthetic image generation models. This may take a few minutes."
-if ! python3 bitmind/validator/verify_models.py; then
- echo "Failed to verify diffusion models. Please check the configurations or model access permissions."
- exit 1
+if pm2 list | grep -q "$GENERATOR"; then
+ echo "'$GENERATOR' is already running. Deleting it..."
+ pm2 delete $GENERATOR
+ sleep 2
fi
-echo "Starting validator process"
-pm2 start neurons/validator.py --name $VALIDATOR_PROCESS_NAME -- \
- --netuid $NETUID \
- --subtensor.network $SUBTENSOR_NETWORK \
- --subtensor.chain_endpoint $SUBTENSOR_CHAIN_ENDPOINT \
+if pm2 list | grep -q "$PROXY"; then
+ echo "'$PROXY' is already running. Deleting it..."
+ pm2 delete $PROXY
+ sleep 1
+fi
+
+
+###################################
+# START PROCESSES
+###################################
+SN34_CACHE_DIR=$(eval echo "$SN34_CACHE_DIR")
+
+echo "Starting validator and generator | chain_endpoint: $CHAIN_ENDPOINT | netuid: $NETUID"
+
+# Run data generator
+pm2 start neurons/generator.py \
+ --interpreter python3 \
+ --kill-timeout 2000 \
+ --name $GENERATOR \
+ -- \
--wallet.name $WALLET_NAME \
--wallet.hotkey $WALLET_HOTKEY \
- --axon.port $VALIDATOR_AXON_PORT \
- --proxy.port $VALIDATOR_PROXY_PORT
+ --netuid $NETUID \
+ --subtensor.chain_endpoint $CHAIN_ENDPOINT \
+ --cache-dir $SN34_CACHE_DIR \
+ --device $DEVICE
-echo "Starting real data cache updater process"
-pm2 start bitmind/validator/scripts/run_cache_updater.py --name $CACHE_UPDATE_PROCESS_NAME
+# Run validator
+pm2 start neurons/validator.py \
+ --interpreter python3 \
+ --kill-timeout 1000 \
+ --name $VALIDATOR \
+ -- \
+ --wallet.name $WALLET_NAME \
+ --wallet.hotkey $WALLET_HOTKEY \
+ --netuid $NETUID \
+ --subtensor.chain_endpoint $CHAIN_ENDPOINT \
+ --epoch-length 101 \
+ --cache-dir $SN34_CACHE_DIR \
+ --proxy.port $PROXY_PORT \
+ $LOG_PARAM \
+ $AUTO_UPDATE_PARAM \
+ $HEARTBEAT_PARAM
-echo "Starting synthetic data generation process"
-pm2 start bitmind/validator/scripts/run_data_generator.py --name $DATA_GEN_PROCESS_NAME -- \
- --device $DEVICE
+# Run validator proxy
+pm2 start neurons/proxy.py \
+ --interpreter python3 \
+ --kill-timeout 1000 \
+ --name $PROXY \
+ -- \
+ --wallet.name $WALLET_NAME \
+ --wallet.hotkey $WALLET_HOTKEY \
+ --netuid $NETUID \
+ --subtensor.chain_endpoint $CHAIN_ENDPOINT \
+ --proxy.port $PROXY_PORT \
+ --proxy.external_port $PROXY_EXTERNAL_PORT \
+ $LOG_PARAM
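
The validator start script now launches three pm2 processes from a single .env.validator. A hedged usage sketch, assuming the .env.validator template shipped with the repo has been copied and its API keys and wallet fields filled in (the process names are those defined in this diff):

    cp .env.validator.template .env.validator   # set WANDB_API_KEY, HUGGING_FACE_TOKEN, wallet fields
    ./start_validator.sh
    pm2 list                                    # expect sn34-validator, sn34-generator, sn34-proxy
    pm2 logs sn34-validator
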
diff --git a/static/Vali-Arch.png b/static/Vali-Arch.png
deleted file mode 100644
index 65106f94..00000000
Binary files a/static/Vali-Arch.png and /dev/null differ
diff --git a/static/incentive.gif b/static/incentive.gif
deleted file mode 100644
index b7afc571..00000000
Binary files a/static/incentive.gif and /dev/null differ
diff --git a/tests/__init__.py b/tests/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
deleted file mode 100644
index 5df4d9ff..00000000
--- a/tests/fixtures/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-NETUID = 34
diff --git a/tests/fixtures/image_transforms.py b/tests/fixtures/image_transforms.py
deleted file mode 100644
index e4f35a71..00000000
--- a/tests/fixtures/image_transforms.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from functools import partial
-import torchvision.transforms as transforms
-
-from bitmind.validator.config import TARGET_IMAGE_SIZE
-from bitmind.utils.image_transforms import (
- center_crop,
- RandomResizedCropWithParams,
- RandomHorizontalFlipWithParams,
- RandomVerticalFlipWithParams,
- RandomRotationWithParams,
- ConvertToRGB,
- ComposeWithParams,
- get_base_transforms,
- get_random_augmentations
-)
-
-
-TRANSFORMS = [
- center_crop,
- RandomHorizontalFlipWithParams,
- RandomVerticalFlipWithParams,
- partial(RandomRotationWithParams, degrees=20, interpolation=transforms.InterpolationMode.BILINEAR),
- partial(RandomResizedCropWithParams, size=TARGET_IMAGE_SIZE, scale=(0.2, 1.0), ratio=(1.0, 1.0)),
- ConvertToRGB
-]
-
-TRANSFORM_PIPELINES = [
- get_base_transforms(TARGET_IMAGE_SIZE),
- get_random_augmentations(TARGET_IMAGE_SIZE)
-]
\ No newline at end of file
diff --git a/tests/test_forward.py b/tests/test_forward.py
deleted file mode 100644
index 977e8c3e..00000000
--- a/tests/test_forward.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from types import SimpleNamespace
-import pytest
-
-from bitmind.utils.mock import MockValidator
-from bitmind.validator.forward import forward
-from tests.fixtures import NETUID
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("fake_prob", [1., 0.])
-async def test_validator_forward(fake_prob):
- print(f"Configuring mock config and mock validator, fake_prob = {fake_prob}")
- mock_config = SimpleNamespace(
- neuron=SimpleNamespace(
- prompt_type="annotation",
- sample_size=10,
- vpermit_tao_limit=1000
- ),
- wandb=SimpleNamespace(off=True),
- fake_prob=fake_prob,
- netuid=NETUID
- )
- mock_neuron = MockValidator(mock_config)
- print("Calling forward with mock validator")
- try:
- await forward(self=mock_neuron)
-
- except Exception as e:
- pytest.fail(f"validator forward raised an exception: {e}")
diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py
deleted file mode 100644
index 3e09a3e5..00000000
--- a/tests/test_image_transforms.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import pytest
-
-from bitmind.utils.mock import create_random_image
-from tests.fixtures.image_transforms import (
- TRANSFORMS,
- TRANSFORM_PIPELINES
-)
-
-
-@pytest.mark.parametrize("transform", TRANSFORMS)
-def test_create_transform(transform):
- tform = transform()
- assert tform is not None
-
-
-@pytest.mark.parametrize("transform", TRANSFORMS)
-def test_transform_has_expected_methods(transform):
- tform = transform()
- has_forward = hasattr(tform, 'forward')
- has_call = hasattr(tform, '__call__')
- assert has_call or has_forward
- if has_call:
- assert callable(getattr(tform, '__call__'))
- elif has_forward:
- assert callable(getattr(tform, 'forward'))
-
-
-@pytest.mark.parametrize("transform", TRANSFORMS)
-def test_invoke_transform(transform):
- image = create_random_image()
-
- try:
- transformed_image = transform()(image)
- except Exception as e:
- pytest.fail(f"transform pipeline invocation raised an exception: {e}")
-
- assert transformed_image is not None
-
-
-@pytest.mark.parametrize("transform_pipeline", TRANSFORM_PIPELINES)
-def test_invoke_transform_pipeline(transform_pipeline):
- image = create_random_image()
-
- try:
- transformed_image = transform_pipeline(image)
- except Exception as e:
- pytest.fail(f"transform pipeline invocation raised an exception: {e}")
-
- assert transformed_image is not None
-
-
diff --git a/tests/test_mock.py b/tests/test_mock.py
deleted file mode 100644
index 867aa944..00000000
--- a/tests/test_mock.py
+++ /dev/null
@@ -1,113 +0,0 @@
-import pytest
-import asyncio
-import bittensor as bt
-from bitmind.protocol import prepare_image_synapse, ImageSynapse
-from bitmind.utils.mock import (
- MockDendrite,
- MockMetagraph,
- MockSubtensor,
- create_random_image
-)
-
-
-wallet = bt.MockWallet()
-wallet.create(coldkey_use_password=False)
-
-
-@pytest.mark.parametrize("netuid", [1, 2, 3])
-@pytest.mark.parametrize("n", [2, 4, 8, 16, 32, 64])
-@pytest.mark.parametrize("wallet", [wallet, None])
-def test_mock_subtensor(netuid, n, wallet):
- subtensor = MockSubtensor(netuid=netuid, n=n, wallet=wallet)
- neurons = subtensor.neurons(netuid=netuid)
- # Check netuid
- assert subtensor.subnet_exists(netuid)
- # Check network
- assert subtensor.network == "mock"
- assert subtensor.chain_endpoint == "mock_endpoint"
- # Check number of neurons
- assert len(neurons) == (n + 1 if wallet is not None else n)
- # Check wallet
- if wallet is not None:
- assert subtensor.is_hotkey_registered(
- netuid=netuid, hotkey_ss58=wallet.hotkey.ss58_address
- )
-
- for neuron in neurons:
- assert type(neuron) == bt.NeuronInfo
- assert subtensor.is_hotkey_registered(
- netuid=netuid, hotkey_ss58=neuron.hotkey
- )
-
-
-@pytest.mark.parametrize("n", [16, 32, 64])
-def test_mock_metagraph(n):
- mock_subtensor = MockSubtensor(netuid=n, n=n)
- mock_metagraph = MockMetagraph(netuid=n, subtensor=mock_subtensor)
- # Check axons
- axons = mock_metagraph.axons
- assert len(axons) == n
- # Check ip and port
- for axon in axons:
- assert type(axon) == bt.AxonInfo
- assert axon.ip == mock_metagraph.default_ip
- assert axon.port == mock_metagraph.default_port
-
-
-def test_mock_reward_pipeline():
- pass
-
-
-def test_mock_neuron():
- pass
-
-
-@pytest.mark.parametrize("timeout", [0.1, 0.2])
-@pytest.mark.parametrize("min_time", [0, 0, 0])
-@pytest.mark.parametrize("max_time", [0.1, 0.15, 0.2])
-@pytest.mark.parametrize("n", [4, 16, 64])
-def test_mock_dendrite_timings(timeout, min_time, max_time, n):
- mock_wallet = bt.MockWallet(config=None)
- mock_dendrite = MockDendrite(mock_wallet)
- mock_dendrite.MIN_TIME = min_time
- mock_dendrite.MAX_TIME = max_time
- mock_subtensor = MockSubtensor(netuid=n, n=n)
- mock_metagraph = MockMetagraph(netuid=n, subtensor=mock_subtensor)
- axons = mock_metagraph.axons
-
- async def run():
- return await mock_dendrite(
- axons,
- synapse=prepare_image_synapse(create_random_image()),
- timeout=timeout,
- deserialize=False,
- )
-
- eps = 0.2
- responses = asyncio.run(run())
- for synapse in responses:
- assert (
- hasattr(synapse, "dendrite") and type(synapse.dendrite) == bt.TerminalInfo
- )
-
- dendrite = synapse.dendrite
- # check synapse.dendrite has (process_time, status_code, status_message)
- for field in ("process_time", "status_code", "status_message"):
- assert hasattr(dendrite, field) and getattr(dendrite, field) is not None
- # check that the dendrite take between min_time and max_time
- assert min_time <= dendrite.process_time
- assert dendrite.process_time <= max_time + eps
- # check that responses which take longer than timeout have 408 status code
- if dendrite.process_time >= timeout + eps:
- assert dendrite.status_code == 408
- assert dendrite.status_message == "Timeout"
- assert synapse.prediction == -1.
- # check that responses which take less than timeout have 200 status code
- elif dendrite.process_time < timeout:
- assert dendrite.status_code == 200
- assert dendrite.status_message == "OK"
- # check that outputs are not empty for successful responses
- assert (synapse.prediction >= 0.) & (synapse.prediction <= 1.)
- # dont check for responses which take between timeout and max_time because they are not guaranteed to have a status code of 200 or 408
- del mock_subtensor
- del mock_metagraph
\ No newline at end of file
diff --git a/tests/validator/test_generate_image.py b/tests/validator/test_generate_image.py
deleted file mode 100644
index f4cd1705..00000000
--- a/tests/validator/test_generate_image.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import pytest
-from unittest.mock import patch, MagicMock
-from bitmind.synthetic_data_generation.synthetic_data_generator import SyntheticDataGenerator
-from bitmind.validator.config import T2I_MODEL_NAMES
-from PIL import Image
-
-
-@pytest.fixture
-def mock_diffuser():
- """
- Fixture to mock the diffuser models (StableDiffusionXLPipeline and FluxPipeline).
-
- Returns:
- MagicMock: A mock object representing the diffuser pipeline.
- """
- with patch('bitmind.synthetic_image_generation.synthetic_image_generator.StableDiffusionXLPipeline') as mock_sdxl:
- with patch('bitmind.synthetic_image_generation.synthetic_image_generator.FluxPipeline') as mock_flux:
- mock_pipeline = MagicMock()
- test_image = Image.new('RGB', (256, 256))
- mock_pipeline.return_value = {"images": [test_image]}
- mock_pipeline.tokenizer_2 = MagicMock()
- mock_sdxl.from_pretrained.return_value = mock_pipeline
- mock_flux.from_pretrained.return_value = mock_pipeline
- yield mock_pipeline
-
-
-@pytest.fixture
-def mock_image_annotation_generator():
- """
- Fixture to mock the ImageAnnotationGenerator.
-
- Returns:
- MagicMock: A mock object representing the ImageAnnotationGenerator.
- """
- with patch('bitmind.synthetic_image_generation.synthetic_image_generator.ImageAnnotationGenerator') as mock:
- instance = mock.return_value
- instance.process_image.return_value = [{'description': 'A test caption'}]
- yield instance
-
-
-@pytest.mark.parametrize("diffuser_name", T2I_MODEL_NAMES)
-def test_generate_image_with_diffusers(mock_diffuser, mock_image_annotation_generator, diffuser_name):
- """
- Test the image generation process using different diffusion models.
-
- This test verifies the functionality of the SyntheticImageGenerator class,
- specifically its generate method. It checks the correct initialization,
- proper loading of the diffuser model, successful generation of synthetic images,
- and the correct structure and content of the generated data.
-
- Args:
- mock_diffuser (MagicMock): Mocked diffuser pipeline.
- mock_image_annotation_generator (MagicMock): Mocked image annotation generator.
- diffuser_name (str): Name of the diffuser model to test.
-
- The test performs the following:
- 1. Initializes SyntheticImageGenerator with specific parameters.
- 2. Mocks the load_diffuser method and sets the diffuser attribute.
- 3. Generates synthetic image data.
- 4. Verifies the structure and content of the generated data.
- 5. Checks for correct method calls and interactions.
-
- This test is useful for:
- - Validating the image generation process
- - Integration testing with different diffuser models
- """
- generator = SyntheticDataGenerator(
- prompt_type='annotation',
- use_random_diffuser=False,
- diffuser_name=diffuser_name
- )
-
- real_image = {'source': 'test_dataset', 'id': 'test_id'}
-
- with patch.object(generator, 'load_diffuser') as mock_load_diffuser:
- mock_load_diffuser.return_value = mock_diffuser
- generator.diffuser = mock_diffuser
-
- mock_output = MagicMock()
- mock_output.images = [Image.new('RGB', (256, 256))]
- mock_diffuser.return_value = mock_output
-
- with patch.object(generator, 'get_tokenizer_with_min_len') as mock_get_tokenizer:
- mock_get_tokenizer.return_value = (MagicMock(), 77)
-
- generated_data = generator.generate(k=1, real_images=[real_image])
-
- assert isinstance(generated_data, list)
- assert len(generated_data) == 1
-
- image_data = generated_data[0]
- assert isinstance(image_data, dict)
- assert 'prompt' in image_data
- assert 'image' in image_data
- assert 'id' in image_data
- assert 'gen_time' in image_data
-
- assert isinstance(image_data['image'], Image.Image)
- assert image_data['prompt'] == "A test caption"
- assert isinstance(image_data['id'], str)
- assert isinstance(image_data['gen_time'], float)
-
- mock_image_annotation_generator.process_image.assert_called_once()
- mock_load_diffuser.assert_called_once_with(diffuser_name)
-
- assert image_data['image'].size == (256, 256)
- assert image_data['image'].mode == 'RGB'
-
-
-if __name__ == "__main__":
- pytest.main()
diff --git a/tests/validator/test_verify_models.py b/tests/validator/test_verify_models.py
deleted file mode 100644
index 8e669692..00000000
--- a/tests/validator/test_verify_models.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import pytest
-import os
-from unittest.mock import patch, MagicMock, call
-from bitmind.validator.verify_models import is_model_cached, main
-from bitmind.validator.config import T2I_MODEL_NAMES, IMAGE_ANNOTATION_MODEL, TEXT_MODERATION_MODEL
-
-@pytest.fixture
-def mock_expanduser():
- """
- Fixture to mock os.path.expanduser.
-
- Returns:
- MagicMock: A mock object representing os.path.expanduser.
- """
- with patch('os.path.expanduser') as mock:
- mock.return_value = '/mock/home/.cache/huggingface/'
- yield mock
-
-@pytest.fixture
-def mock_isdir():
- """
- Fixture to mock os.path.isdir.
-
- Returns:
- MagicMock: A mock object representing os.path.isdir.
- """
- with patch('os.path.isdir') as mock:
- yield mock
-
-def test_is_model_cached(mock_expanduser, mock_isdir):
- """
- Test the is_model_cached function.
-
- This test verifies the functionality of the is_model_cached function
- under different scenarios of model caching.
-
- Args:
- mock_expanduser (MagicMock): Mocked os.path.expanduser function.
- mock_isdir (MagicMock): Mocked os.path.isdir function.
-
- The test performs the following:
- 1. Checks if the function correctly identifies a cached model.
- 2. Verifies the correct path is checked for the model cache.
- 3. Checks if the function correctly identifies a non-cached model.
-
- This test is useful for:
- - Ensuring correct behavior of model cache checking
- - Verifying path construction for model cache
- """
- mock_isdir.return_value = True
- assert is_model_cached('test/model') == True
- mock_isdir.assert_called_with('/mock/home/.cache/huggingface/models--test--model')
-
- mock_isdir.return_value = False
- assert is_model_cached('test/model') == False
-
-@patch('bitmind.validator.verify_models.SyntheticImageGenerator')
-@patch('bitmind.validator.verify_models.is_model_cached')
-def test_main(mock_is_model_cached, MockSyntheticImageGenerator):
- """
- Test the main function of the verify_models module.
-
- This test verifies the behavior of the main function, particularly
- its interaction with SyntheticImageGenerator under various scenarios.
-
- Args:
- mock_is_model_cached (MagicMock): Mocked is_model_cached function.
- MockSyntheticImageGenerator (MagicMock): Mocked SyntheticImageGenerator class.
-
- The test performs the following:
- 1. Simulates a scenario where no models are cached.
- 2. Calls the main function.
- 3. Verifies that SyntheticImageGenerator is called with correct parameters
- for different model types and diffuser names.
-
- This test is useful for:
- - Ensuring correct initialization of SyntheticImageGenerator
- - Verifying handling of different model types and diffuser names
- - Regression testing for the main function's behavior
- """
- # Setup mock_is_model_cached to simulate caching behavior
- # Assume no models are cached
- mock_is_model_cached.side_effect = lambda model_name: False
-
- # Call the main function
- main()
-
- # Expected calls with varying parameters based on model type
- expected_calls = [
- call(prompt_type='annotation', use_random_diffuser=True, diffuser_name=None), # For IMAGE_ANNOTATION_MODEL and TEXT_MODERATION_MODEL
- *[call(prompt_type='annotation', use_random_diffuser=False, diffuser_name=name) for name in T2I_MODEL_NAMES] # For each name in T2I_MODEL_NAMES
- ]
-
- # Verify all calls to SyntheticImageGenerator with the correct parameters
- MockSyntheticImageGenerator.assert_has_calls(expected_calls, any_order=True)
-
-if __name__ == "__main__":
- pytest.main()
\ No newline at end of file