diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..68b4fed
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,199 @@
+name: evo_env
+channels:
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - _openmp_mutex=5.1=1_gnu
+  - blas=1.0=mkl
+  - bottleneck=1.3.7=py312ha883a20_0
+  - brotli=1.0.9=h5eee18b_8
+  - brotli-bin=1.0.9=h5eee18b_8
+  - bzip2=1.0.8=h5eee18b_6
+  - ca-certificates=2024.7.2=h06a4308_0
+  - et_xmlfile=1.1.0=py312h06a4308_1
+  - expat=2.6.2=h6a678d5_0
+  - freetype=2.12.1=h4a9f257_0
+  - intel-openmp=2023.1.0=hdb19cb5_46306
+  - joblib=1.4.2=py312h06a4308_0
+  - jpeg=9e=h5eee18b_1
+  - lcms2=2.12=h3be6417_0
+  - ld_impl_linux-64=2.38=h1181459_1
+  - lerc=3.0=h295c915_0
+  - libbrotlicommon=1.0.9=h5eee18b_8
+  - libbrotlidec=1.0.9=h5eee18b_8
+  - libbrotlienc=1.0.9=h5eee18b_8
+  - libdeflate=1.17=h5eee18b_1
+  - libffi=3.4.4=h6a678d5_1
+  - libgcc-ng=11.2.0=h1234567_1
+  - libgfortran-ng=11.2.0=h00389a5_1
+  - libgfortran5=11.2.0=h1234567_1
+  - libgomp=11.2.0=h1234567_1
+  - libpng=1.6.39=h5eee18b_0
+  - libstdcxx-ng=11.2.0=h1234567_1
+  - libtiff=4.5.1=h6a678d5_0
+  - libuuid=1.41.5=h5eee18b_0
+  - libwebp-base=1.3.2=h5eee18b_0
+  - lz4-c=1.9.4=h6a678d5_1
+  - matplotlib-base=3.8.4=py312h526ad5a_0
+  - mkl=2023.1.0=h213fc3f_46344
+  - mkl-service=2.4.0=py312h5eee18b_1
+  - mkl_fft=1.3.8=py312h5eee18b_0
+  - mkl_random=1.2.4=py312hdb19cb5_0
+  - ncurses=6.4=h6a678d5_0
+  - numexpr=2.8.7=py312hf827012_0
+  - numpy=1.26.4=py312hc5e2394_0
+  - numpy-base=1.26.4=py312h0da6c21_0
+  - openjpeg=2.4.0=h9ca470c_2
+  - openpyxl=3.1.2=py312h5eee18b_0
+  - openssl=3.0.14=h5eee18b_0
+  - packaging=24.1=py312h06a4308_0
+  - pandas=2.2.2=py312h526ad5a_0
+  - pillow=10.4.0=py312h5eee18b_0
+  - pip=24.0=py312h06a4308_0
+  - pybind11-abi=5=hd3eb1b0_0
+  - python=3.12.4=h5148396_1
+  - python-dateutil=2.9.0post0=py312h06a4308_2
+  - python-tzdata=2023.3=pyhd3eb1b0_0
+  - pytz=2024.1=py312h06a4308_0
+  - readline=8.2=h5eee18b_0
+  - scipy=1.13.1=py312hc5e2394_0
+  - seaborn=0.13.2=py312h06a4308_0
+  - setuptools=69.5.1=py312h06a4308_0
+  - six=1.16.0=pyhd3eb1b0_1
+  - sqlite=3.45.3=h5eee18b_0
+  - tbb=2021.8.0=hdb19cb5_0
+  - tk=8.6.14=h39e8969_0
+  - unicodedata2=15.1.0=py312h5eee18b_0
+  - wheel=0.43.0=py312h06a4308_0
+  - xz=5.4.6=h5eee18b_1
+  - zlib=1.2.13=h5eee18b_1
+  - zstd=1.5.5=hc292b87_2
+  - pip:
+    - accelerate==0.32.1
+    - anyio==4.4.0
+    - argon2-cffi==23.1.0
+    - argon2-cffi-bindings==21.2.0
+    - arrow==1.3.0
+    - asttokens==2.4.1
+    - async-lru==2.0.4
+    - attrs==23.2.0
+    - babel==2.15.0
+    - beautifulsoup4==4.12.3
+    - bleach==6.1.0
+    - certifi==2024.7.4
+    - cffi==1.16.0
+    - charset-normalizer==3.3.2
+    - comm==0.2.2
+    - contourpy==1.2.1
+    - cycler==0.12.1
+    - debugpy==1.8.2
+    - decorator==5.1.1
+    - defusedxml==0.7.1
+    - executing==2.0.1
+    - fair-esm==2.0.1
+    - fastjsonschema==2.20.0
+    - filelock==3.15.4
+    - fonttools==4.53.1
+    - fqdn==1.5.1
+    - fsspec==2024.6.1
+    - h11==0.14.0
+    - httpcore==1.0.5
+    - httpx==0.27.0
+    - huggingface-hub==0.23.5
+    - idna==3.7
+    - ipykernel==6.29.5
+    - ipython==8.26.0
+    - ipywidgets==8.1.3
+    - isoduration==20.11.0
+    - jedi==0.19.1
+    - jinja2==3.1.4
+    - json5==0.9.25
+    - jsonpointer==3.0.0
+    - jsonschema==4.23.0
+    - jsonschema-specifications==2023.12.1
+    - jupyter==1.0.0
+    - jupyter-client==8.6.2
+    - jupyter-console==6.6.3
+    - jupyter-core==5.7.2
+    - jupyter-events==0.10.0
+    - jupyter-lsp==2.2.5
+    - jupyter-server==2.14.2
+    - jupyter-server-terminals==0.5.3
+    - jupyterlab==4.2.3
+    - jupyterlab-pygments==0.3.0
+    - jupyterlab-server==2.27.3
+    - jupyterlab-widgets==3.0.11
+    - kiwisolver==1.4.5
+    - markupsafe==2.1.5
+    - matplotlib==3.9.1
+    - matplotlib-inline==0.1.7
+    - mistune==3.0.2
+    - mpmath==1.3.0
+    - nbclient==0.10.0
+    - nbconvert==7.16.4
+    - nbformat==5.10.4
+    - nest-asyncio==1.6.0
+    - networkx==3.3
+    - notebook==7.2.1
+    - notebook-shim==0.2.4
+    - nvidia-cublas-cu12==12.1.3.1
+    - nvidia-cuda-cupti-cu12==12.1.105
+    - nvidia-cuda-nvrtc-cu12==12.1.105
+    - nvidia-cuda-runtime-cu12==12.1.105
+    - nvidia-cudnn-cu12==8.9.2.26
+    - nvidia-cufft-cu12==11.0.2.54
+    - nvidia-curand-cu12==10.3.2.106
+    - nvidia-cusolver-cu12==11.4.5.107
+    - nvidia-cusparse-cu12==12.1.0.106
+    - nvidia-nccl-cu12==2.20.5
+    - nvidia-nvjitlink-cu12==12.5.82
+    - nvidia-nvtx-cu12==12.1.105
+    - overrides==7.7.0
+    - pandocfilters==1.5.1
+    - parso==0.8.4
+    - pexpect==4.9.0
+    - platformdirs==4.2.2
+    - prometheus-client==0.20.0
+    - prompt-toolkit==3.0.47
+    - psutil==6.0.0
+    - ptyprocess==0.7.0
+    - pure-eval==0.2.2
+    - pycparser==2.22
+    - pygments==2.18.0
+    - pyparsing==3.1.2
+    - python-json-logger==2.0.7
+    - pyyaml==6.0.1
+    - pyzmq==26.0.3
+    - qtconsole==5.5.2
+    - qtpy==2.4.1
+    - referencing==0.35.1
+    - regex==2024.5.15
+    - requests==2.32.3
+    - rfc3339-validator==0.1.4
+    - rfc3986-validator==0.1.1
+    - rpds-py==0.19.0
+    - safetensors==0.4.3
+    - send2trash==1.8.3
+    - sniffio==1.3.1
+    - soupsieve==2.5
+    - stack-data==0.6.3
+    - sympy==1.13.0
+    - terminado==0.18.1
+    - tinycss2==1.3.0
+    - tokenizers==0.15.2
+    - torch==2.3.1
+    - tornado==6.4.1
+    - tqdm==4.66.4
+    - traitlets==5.14.3
+    - transformers==4.38.0
+    - types-python-dateutil==2.9.0.20240316
+    - typing-extensions==4.12.2
+    - tzdata==2024.1
+    - uri-template==1.3.0
+    - urllib3==2.2.2
+    - wcwidth==0.2.13
+    - webcolors==24.6.0
+    - webencodings==0.5.1
+    - websocket-client==1.8.0
+    - widgetsnbextension==4.0.11
+prefix: /home/carlos/miniconda3/envs/evo_env
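Reviewer note: the file above is a full `conda env export`. The environment can be recreated with `conda env create -f environment.yml` and activated with `conda activate evo_env`; the machine-specific `prefix:` line is an artifact of the export and is ignored by `conda env create`.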
diff --git a/evo_prot_grad/common/sampler.py b/evo_prot_grad/common/sampler.py
index 77d04bc..278a03c 100644
--- a/evo_prot_grad/common/sampler.py
+++ b/evo_prot_grad/common/sampler.py
@@ -7,6 +7,34 @@
 import pandas as pd
 import gc
 
+def convert_full_to_relative_sequences(full_sequences, ref_seq):
+    """
+    Convert a list of full sequences into relative (mutation-list) format based on a reference sequence.
+
+    Parameters:
+        full_sequences (list): List of full sequences (strings) to be converted.
+        ref_seq (str): Reference sequence (string) to compare against.
+
+    Returns:
+        list: List of relative sequences in the format 'A1B-A3C', where A is the reference
+              amino acid, 1 is the 1-based site index, and B is the mutant amino acid.
+    """
+    relative_sequences = []
+
+    for seq in full_sequences:
+        mutations = []
+        for i, (ref_aa, mut_aa) in enumerate(zip(ref_seq, seq)):
+            if ref_aa != mut_aa:
+                mutation = f"{ref_aa}{i+1}{mut_aa}"
+                mutations.append(mutation)
+        relative_sequences.append('-'.join(mutations))
+
+    return relative_sequences
+
+def prep_seqs(seqs, ref_seq):
+    """Strip the spaces EvoProtGrad inserts and convert each sequence to relative format."""
+    return [convert_full_to_relative_sequences([seq.replace(" ", "")], ref_seq.replace(" ", "")) for seq in seqs]
+
 class DirectedEvolution:
     """Main class for plug and play directed evolution with gradient-based discrete MCMC.
     """
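Reviewer note: for reference, a quick sketch of what the new helpers produce (illustration only; `ref` is a made-up four-residue sequence):

```python
# Illustration of the helper functions above (not part of the patch).
from evo_prot_grad.common.sampler import convert_full_to_relative_sequences, prep_seqs

ref = "MKLV"  # hypothetical reference sequence
print(convert_full_to_relative_sequences(["MALV", "MKLI"], ref))
# ['K2A', 'V4I']  -- 1-based positions; a sequence identical to ref maps to ''
print(prep_seqs(["M A L V"], ref))
# [['K2A']]  -- note each entry comes back wrapped in a singleton list
```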
""" @@ -134,7 +162,8 @@ def _product_of_experts(self, inputs: List[str]) -> Tuple[List[torch.Tensor], to oh, score = expert(inputs) ohs += [oh] scores += [expert.temperature * score] - # sum scores over experts + inputs2 = [seq.replace(" ", "") for seq in inputs] + scores2 = [score.detach().cpu().numpy() for score in scores] return ohs, torch.stack(scores, dim=0).sum(dim=0) @@ -184,6 +213,8 @@ def prepare_results(self, variants, scores, n_seqs_to_keep=None): sequence_scores = {} # Iterate through the flattened list to count sequences and record first appearance + if len(scores.shape) == 1: # If scores has a single dimension + scores = scores.reshape(-1, 1) # Reshape to have two dimensions for i, sublist in enumerate(variants): for j, seq in enumerate(sublist): flat_seq = ''.join(seq.split()) @@ -242,6 +273,11 @@ def save_results(self, filename, variants, scores, n_seqs_to_keep=10000): for key, value in params.items(): f.write(f'{key}: {value}\n') + def _recompute_score(self, seq): + ohs, score = self.experts[0]([seq]) + return score.detach().flatten().cpu().numpy() + + def __call__(self) -> Tuple[List[str], np.ndarray]: """ Run the gradient-based MCMC sampler. @@ -261,6 +297,9 @@ def __call__(self) -> Tuple[List[str], np.ndarray]: pos_mask = pos_mask.reshape(self.parallel_chains,-1) for i in range(self.n_steps): + print(f"#### Starting iteration {i}.") + start_seqs = self.chains.copy() + self.chains = self.canonical_chain_tokenizer.decode(cur_chains_oh) # to reflect any reset chains that reached multiple mutations ###### sample path length U = torch.randint(1, 2 * self.max_pas_path_length, size=(self.parallel_chains,1)) max_u = int(torch.max(U).item()) @@ -274,7 +313,6 @@ def __call__(self) -> Tuple[List[str], np.ndarray]: # Need to use the string version of the chain to pass to experts ohs, PoE = self._product_of_experts(self.chains) grad_x = self._compute_gradients(ohs, PoE) - # do U intermediate steps with torch.no_grad(): for step in range(max_u): @@ -311,12 +349,12 @@ def __call__(self) -> Tuple[List[str], np.ndarray]: row_select = changes_all.sum(-1).unsqueeze(-1) # [n_chains,seq_len,1] new_x = cur_chains_oh * (1.0 - row_select) + changes_all cur_u_mask = u_mask[:, step].unsqueeze(-1).unsqueeze(-1) - cur_chains_oh = cur_u_mask * new_x + (1 - cur_u_mask) * cur_chains_oh + cur_chains_oh2 = cur_u_mask * new_x + (1 - cur_u_mask) * cur_chains_oh - y = cur_chains_oh + y = cur_chains_oh2.clone() # last step - y_strs = self.canonical_chain_tokenizer.decode(y) + y_strs = self.canonical_chain_tokenizer.decode(y) # Created string version of potentially new seqs. 
@@ -261,6 +297,9 @@ def __call__(self) -> Tuple[List[str], np.ndarray]:
         pos_mask = pos_mask.reshape(self.parallel_chains,-1)
 
         for i in range(self.n_steps):
+            print(f"#### Starting iteration {i}.")
+            start_seqs = self.chains.copy()  # NOTE: unused; debugging leftover
+            self.chains = self.canonical_chain_tokenizer.decode(cur_chains_oh)  # reflect any chains reset after reaching the mutation cap
             ###### sample path length
             U = torch.randint(1, 2 * self.max_pas_path_length, size=(self.parallel_chains,1))
             max_u = int(torch.max(U).item())
@@ -274,7 +313,6 @@ def __call__(self) -> Tuple[List[str], np.ndarray]:
             # Need to use the string version of the chain to pass to experts
             ohs, PoE = self._product_of_experts(self.chains)
             grad_x = self._compute_gradients(ohs, PoE)
-
             # do U intermediate steps
             with torch.no_grad():
                 for step in range(max_u):
@@ -311,12 +349,12 @@ def __call__(self) -> Tuple[List[str], np.ndarray]:
                     row_select = changes_all.sum(-1).unsqueeze(-1) # [n_chains,seq_len,1]
                     new_x = cur_chains_oh * (1.0 - row_select) + changes_all
                     cur_u_mask = u_mask[:, step].unsqueeze(-1).unsqueeze(-1)
-                    cur_chains_oh = cur_u_mask * new_x + (1 - cur_u_mask) * cur_chains_oh
+                    cur_chains_oh2 = cur_u_mask * new_x + (1 - cur_u_mask) * cur_chains_oh  # keep cur_chains_oh as the pre-proposal state
 
-                y = cur_chains_oh
+                y = cur_chains_oh2.clone()
                 # last step
-                y_strs = self.canonical_chain_tokenizer.decode(y)
+                y_strs = self.canonical_chain_tokenizer.decode(y)  # string version of the proposed seqs, to compare against the current ones
                 ohs, proposed_PoE = self._product_of_experts(y_strs)
                 grad_y = self._compute_gradients(ohs, proposed_PoE)
                 grad_y = grad_y.detach()
@@ -336,20 +374,32 @@ def __call__(self) -> Tuple[List[str], np.ndarray]:
                 #log_acc = log_backwd - log_fwd
                 m_term = (proposed_PoE.squeeze() - PoE.squeeze())
                 log_acc = m_term + log_ratio
-                #print(f"log_acc has shape {log_acc}, m_term has shape {m_term.shape}, and log_ratio has shape {log_ratio.shape}.")
                 accepted = (log_acc.exp() >= torch.rand_like(log_acc)).float().view(-1, *([1] * x_rank)) # original
-                #accepted = (log_acc.exp() >= torch.rand_like(log_acc)).float().view(-1, 1, 1)
-                #print(f"y has shape {y.shape}, and accepted has shape {accepted.shape}")
-                cur_chains_oh = y * accepted + (1.0 - accepted) * cur_chains_oh
-
+                # handle acceptance with an explicit per-chain loop instead of the masked update
+                accepted, PoE = accepted.squeeze(), PoE.squeeze().clone()
+
+                dec_seqs = self.canonical_chain_tokenizer.decode(cur_chains_oh2)  # NOTE: unused; debugging leftover
+                for chain_idx in range(len(accepted)):  # avoid reusing `i`, the outer step counter
+                    if accepted[chain_idx] == 1:
+                        cur_chains_oh[chain_idx, :, :] = y[chain_idx, :, :]
+                        PoE[chain_idx] = proposed_PoE[chain_idx]
+                        self.chains[chain_idx] = self.canonical_chain_tokenizer.decode(cur_chains_oh[chain_idx, :, :].unsqueeze(0))[0]
+                #cur_chains_oh2 = y * accepted + (1.0 - accepted) * cur_chains_oh
+                #for chain_idx in range(self.parallel_chains):
+                #    if accepted.squeeze()[chain_idx] == 0:
+                #        # compare cur_chains_oh and cur_chains_oh2 for chains where accepted is 0
+                #        if not torch.equal(cur_chains_oh[chain_idx], cur_chains_oh2[chain_idx]):
+                #cur_chains_oh = cur_chains_oh2
+
                 # Current chain state book-keeping
-                self.chains_oh = cur_chains_oh
-                self.chains = self.canonical_chain_tokenizer.decode(cur_chains_oh)
+                self.chains_oh = cur_chains_oh.clone()
+                # Check that cur_chains_oh and self.chains stay synchronized
+                decoded_sequences = self.canonical_chain_tokenizer.decode(self.chains_oh)  # NOTE: unused; debugging leftover
+                dec_seqs = self.canonical_chain_tokenizer.decode(self.chains_oh)  # NOTE: unused; debugging leftover
                 # History book-keeping
                 self.chains_oh_history += [cur_chains_oh.clone()]
-                PoE = proposed_PoE.squeeze() * accepted.squeeze() + PoE.squeeze() * (1. - accepted.squeeze())
                 self.PoE_history += [PoE.clone()]
-
+
                 if self.verbose:
                     x_strs = self.canonical_chain_tokenizer.decode(cur_chains_oh)
                     for idx in range(log_acc.size(0)):
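Reviewer note: the per-chain loop above implements the standard Metropolis-Hastings test, accepting chain c when exp(log_acc_c) >= u_c with u_c ~ Uniform(0,1). A self-contained sketch under that reading (illustration only, not project code):

```python
# Standalone sketch of the per-chain Metropolis-Hastings test used above.
import torch

def mh_accept(poe_x, poe_y, log_ratio):
    """poe_x, poe_y, log_ratio: [n_chains] tensors; returns a boolean accept mask."""
    log_acc = (poe_y - poe_x) + log_ratio  # m_term plus the proposal correction
    return log_acc.exp() >= torch.rand_like(log_acc)

accept = mh_accept(torch.zeros(2), torch.tensor([1.0, -5.0]), torch.zeros(2))
# chain 0 (uphill, log_acc > 0) is always accepted;
# chain 1 (downhill) survives with probability exp(-5).
```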
@@ -364,7 +414,7 @@ def __call__(self) -> Tuple[List[str], np.ndarray]:
             mask_flag = (dist >= self.max_mutations).bool()
             mask_flag = mask_flag.reshape(self.parallel_chains)
             cur_chains_oh[mask_flag] = self.wt_oh
-
+
             if i > 10 and i % 100 == 0:
                 print(f"Finished step {i} out of {self.n_steps}.")
                 if torch.cuda.is_available():
@@ -376,8 +426,11 @@ def __call__(self) -> Tuple[List[str], np.ndarray]:
             scores_ = self.PoE_history[-1].detach().cpu().numpy()
         elif self.output == 'all':
             output_ = []
-            for i in range(len(self.chains_oh_history)):
-                output_ += [ self.canonical_chain_tokenizer.decode(self.chains_oh_history[i]) ]
+            for j in range(len(self.chains_oh_history)):
+                decoded_sequences = self.canonical_chain_tokenizer.decode(self.chains_oh_history[j])
+                scores = self.PoE_history[j].detach().cpu().numpy()  # NOTE: unused; PoE scores as numpy for easier handling
+                seqs = prep_seqs(decoded_sequences, self.wtseq)  # NOTE: unused; relative-format view of the decoded seqs
+                output_ += [ decoded_sequences ]
             scores_ = torch.stack(self.PoE_history).detach().cpu().numpy()
         elif self.output == 'best':
             best_idxs = torch.stack(self.PoE_history).argmax(0)
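Reviewer note: with `output='all'`, a run returns one decoded batch per step plus the stacked PoE history, and `prepare_results`/`save_results` then deduplicate and rank the variants. A hedged usage sketch (assuming `de` is an already-configured `DirectedEvolution` instance; signatures as shown in this diff):

```python
# Hedged sketch: consuming the sampler outputs (illustration only).
variants, scores = de()  # variants: n_steps lists of decoded chains; scores: [n_steps, parallel_chains]
de.save_results('variants.csv', variants, scores, n_seqs_to_keep=10000)
```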
+ """ + self.model, self.model_params, _ = ModelCheckpoint.load_model(model_path) + self.tokenizer = self.model.alphabet.get_batch_converter() + # Set requires_grad to True for ESM model parameters + for param in self.model.esm_model.parameters(): + param.requires_grad = True + # Apply OneHotEmbedding to the model's ESM embedding layer after loading the model + self.model.esm_model.embed_tokens = OneHotEmbedding(self.model.esm_model.embed_tokens) + self.model.esm_model.embed_tokens.weight = nn.Parameter(self.model.esm_model.embed_tokens.weight.half()) + self.model.eval() + self.task_weights = task_weights + + self.temperature = temperature + self.device = device + self.model.to(device) + + vocab = {char: idx for idx, char in enumerate(self.model.alphabet.all_toks)} + scoring_strategy = "dummy" # Placeholder as we do not use variant_scoring + super().__init__(temperature, self.model, vocab, scoring_strategy, device) + + self._wt_oh = None + + def _get_last_one_hots(self) -> torch.Tensor: + """Returns the one-hot tensors most recently passed as input.""" + return self.model.esm_model.embed_tokens.one_hots + + def move_to_device(self, batch: Dict[str, Any], device: str) -> Dict[str, Any]: + """Move batch to the specified device.""" + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.to(device) + elif isinstance(value, dict): # In case of nested dictionaries + batch[key] = {k: v.to(device) for k, v in value.items() if isinstance(v, torch.Tensor)} + return batch + + def prepare_batch(self, inputs: List[str], device: str): + """Prepare batch from absolute sequences """ + #print(f"inputs is {inputs}") + _, rel_seqs_list_of_dicts = list_of_dicts_from_list_of_sequence_strings(inputs, self.model.ref_seq) + #print(f"rel_seqs_list_of_dicts is {rel_seqs_list_of_dicts}.") + batch_labels, batch_strs, batch_tokens = self.tokenizer([(str(i), seq) for i, seq in enumerate(inputs)]) + batch_tokens = batch_tokens.to(device) + rel_seqs_tensors = convert_rel_seqs_to_tensors(rel_seqs_list_of_dicts) + rel_seqs_tensors_padded = pad_rel_seq_tensors_with_nan(rel_seqs_tensors) + return batch_tokens, rel_seqs_tensors_padded + + def tokenize(self, inputs: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: + """Tokenizes a list of protein sequences.""" + return self.prepare_batch(inputs, self.device) + + def get_model_output(self, inputs: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: + """Returns the one-hot sequences and model predictions for each amino acid in the input sequence.""" + inputs = [seq.replace(" ", "") for seq in inputs] # since EvoProtGrad adds spaces between amino acids, sigh + batch_tokens, rel_seqs_tensors_padded = self.prepare_batch(inputs, self.device) + self.move_to_device({'tokens': batch_tokens, 'rel_seqs': rel_seqs_tensors_padded}, self.device) + predictions = self.model(batch_tokens, rel_seqs_tensors_padded) + oh = self._get_last_one_hots() + return oh, predictions + + def __call__(self, inputs: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: + """Returns the one-hot sequences and expert score.""" + oh, predictions = self.get_model_output(inputs) + scores = compute_model_scores(predictions, self.task_weights) + return oh.float(), scores.float() + + def init_wildtype(self, wt_seq: str) -> None: + """Set the one-hot encoded wildtype sequence for this expert.""" + self._wt_oh, self._wt_preds = self.get_model_output([wt_seq]) + self._wt_score = compute_model_scores(self._wt_preds, self.task_weights).detach().cpu().numpy() + print(f"Wt sequence has score 
+
+
+def build_custom_expert(model_path: str, temperature: float, device: str = 'cpu', task_weights: Optional[Dict] = None) -> CustomProteinExpert:
+    """Builds a CustomProteinExpert from a checkpoint."""
+    return CustomProteinExpert(model_path=model_path, temperature=temperature, device=device, task_weights=task_weights)
\ No newline at end of file
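Reviewer note: taken together, the new expert plugs into the sampler roughly as follows. This is a hedged sketch: the checkpoint path and wild-type sequence are hypothetical, and the `DirectedEvolution` constructor arguments follow upstream EvoProtGrad conventions, so they may differ in this fork.

```python
# Hedged end-to-end sketch (illustration only; 'checkpoints/seq2fitness.pt' is a
# hypothetical path, and the DirectedEvolution kwargs follow upstream EvoProtGrad).
from evo_prot_grad.experts.seq2fitness_expert import build_custom_expert
from evo_prot_grad.common.sampler import DirectedEvolution

wt_seq = "MKTAYIAKQR"  # hypothetical wild-type sequence
expert = build_custom_expert(model_path='checkpoints/seq2fitness.pt',
                             temperature=1.0, device='cuda')
de = DirectedEvolution(wt_protein=wt_seq,
                       experts=[expert],
                       parallel_chains=8,
                       n_steps=500,
                       max_mutations=10,
                       output='all')
variants, scores = de()
de.save_results('results.csv', variants, scores, n_seqs_to_keep=10000)
```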