Skip to content

torchsim

TorchSim calculators for molecular properties.

FairChemCalculator dataclass

Bases: TorchSimCalculator, MachineLearnedInteratomicPotentialCalculator, MSONable


              flowchart TD
              jfchemistry.calculators.torchsim.FairChemCalculator[FairChemCalculator]
              jfchemistry.calculators.torchsim.torchsim_calculator.TorchSimCalculator[TorchSimCalculator]
              jfchemistry.calculators.base.MachineLearnedInteratomicPotentialCalculator[MachineLearnedInteratomicPotentialCalculator]
              jfchemistry.calculators.base.Calculator[Calculator]

                              jfchemistry.calculators.torchsim.torchsim_calculator.TorchSimCalculator --> jfchemistry.calculators.torchsim.FairChemCalculator
                                jfchemistry.calculators.base.Calculator --> jfchemistry.calculators.torchsim.torchsim_calculator.TorchSimCalculator
                

                jfchemistry.calculators.base.MachineLearnedInteratomicPotentialCalculator --> jfchemistry.calculators.torchsim.FairChemCalculator
                                jfchemistry.calculators.base.Calculator --> jfchemistry.calculators.base.MachineLearnedInteratomicPotentialCalculator
                



              click jfchemistry.calculators.torchsim.FairChemCalculator href "" "jfchemistry.calculators.torchsim.FairChemCalculator"
              click jfchemistry.calculators.torchsim.torchsim_calculator.TorchSimCalculator href "" "jfchemistry.calculators.torchsim.torchsim_calculator.TorchSimCalculator"
              click jfchemistry.calculators.base.MachineLearnedInteratomicPotentialCalculator href "" "jfchemistry.calculators.base.MachineLearnedInteratomicPotentialCalculator"
              click jfchemistry.calculators.base.Calculator href "" "jfchemistry.calculators.base.Calculator"
            

FairChem Calculator.

ATTRIBUTE DESCRIPTION
name

Name of the calculator (default: "FairChem Calculator").

TYPE: str

model

The FairChem model to use (default: "uma-s-1").

TYPE: model_types

task

The task to use (default: "omol").

TYPE: task_type

compute_stress

Whether to compute the stress (default: False).

TYPE: bool

Source code in jfchemistry/calculators/torchsim/fairchem_calculator.py
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
@dataclass
class FairChemCalculator(
    TorchSimCalculator, MachineLearnedInteratomicPotentialCalculator, MSONable
):
    """FairChem Calculator.

    Evaluates the potential energy, the forces and (optionally) the stress of a
    system by running a FairChem model through ``ts.static``.

    Attributes:
        name: Name of the calculator (default: "FairChem Calculator").
        model: The FairChem model to use (default: "uma-s-1").
        task: The task to use (default: "omol").
        compute_stress: Whether to compute the stress (default: False).
    """

    name: str = "FairChem Calculator"
    model: model_types = field(
        default="uma-s-1", metadata={"description": "The FairChem model to use"}
    )
    task: task_type = field(default="omol", metadata={"description": "The task to use"})
    compute_stress: bool = field(
        default=False, metadata={"description": "Whether to compute the stress"}
    )

    def _get_model(self) -> FairChemModel:
        """Build the FairChem model and cache it on the instance.

        Returns:
            The newly constructed ``FairChemModel``; it is also stored as
            ``self._model`` so that ``_get_properties`` can reuse it.
        """
        # NOTE(review): ``self.device`` is not declared on this class; it is
        # presumably provided by one of the base calculators -- confirm.
        model = FairChemModel(
            model=self.model,
            task_name=self.task,
            device=device(self.device),
            compute_stress=self.compute_stress,
        )
        self._model = model
        return model

    def _get_properties(self, system: SiteCollection) -> Properties:
        """Compute the FairChem properties of ``system``.

        Args:
            system: The structure to evaluate. It is converted with
                ``to_ase_atoms()`` before being passed to ``ts.static``.

        Returns:
            A ``FairChemProperties`` object holding the forces (eV/angstrom), the
            total energy (eV) and, when ``compute_stress`` is enabled, the stress
            (eV/angstrom**3). The stress entry is ``None`` otherwise.
        """
        # Build the model lazily; later calls reuse the cached instance.
        if not hasattr(self, "_model"):
            self._get_model()

        # NOTE(review): the integer keys follow the trajectory-reporter
        # convention of ``ts`` (presumably a reporting interval) -- confirm
        # against the TorchSim documentation before changing them.
        prop_calculators = {
            10: {"potential_energy": lambda state: state.energy},
            20: {"forces": lambda state: state.forces},
        }
        if self.compute_stress:
            prop_calculators[30] = {"stress": lambda state: state.stress}

        final_results = ts.static(
            system=system.to_ase_atoms(),
            model=self._model,
            # we don't want to save any trajectories this time, just get the properties
            trajectory_reporter={"filenames": None, "prop_calculators": prop_calculators},
        )

        # Only a single system is submitted, so only the first entry is read.
        results = final_results[0]
        forces = results["forces"]
        energy = results["potential_energy"]

        # Always bind the stress property so it is never referenced while
        # undefined; it stays ``None`` unless the stress was requested.
        stress_property = None
        if self.compute_stress:
            stress_property = SystemProperty(
                name="Stress",
                value=results["stress"].tolist() * ureg.eV / ureg.angstrom**3,
                description=f"Stress predicted by the {self.model} model and {self.task}",
            )

        return FairChemProperties(
            atomic=FairChemAtomicProperties(
                forces=AtomicProperty(
                    name="FairChem Forces",
                    value=forces.tolist() * ureg.eV / ureg.angstrom,
                    description=f"Forces predicted by the {self.model} model and {self.task}",
                ),
            ),
            system=FairChemSystemProperties(
                total_energy=SystemProperty(
                    name="Total Energy",
                    value=energy.tolist() * ureg.eV,
                    description=f"Total energy predicted by the {self.model} model and {self.task}",
                ),
                stress=stress_property,
            ),
        )

OrbCalculator dataclass

Bases: TorchSimCalculator, MachineLearnedInteratomicPotentialCalculator, MSONable


              flowchart TD
              jfchemistry.calculators.torchsim.OrbCalculator[OrbCalculator]
              jfchemistry.calculators.torchsim.torchsim_calculator.TorchSimCalculator[TorchSimCalculator]
              jfchemistry.calculators.base.MachineLearnedInteratomicPotentialCalculator[MachineLearnedInteratomicPotentialCalculator]
              jfchemistry.calculators.base.Calculator[Calculator]

                              jfchemistry.calculators.torchsim.torchsim_calculator.TorchSimCalculator --> jfchemistry.calculators.torchsim.OrbCalculator
                                jfchemistry.calculators.base.Calculator --> jfchemistry.calculators.torchsim.torchsim_calculator.TorchSimCalculator
                

                jfchemistry.calculators.base.MachineLearnedInteratomicPotentialCalculator --> jfchemistry.calculators.torchsim.OrbCalculator
                                jfchemistry.calculators.base.Calculator --> jfchemistry.calculators.base.MachineLearnedInteratomicPotentialCalculator
                



              click jfchemistry.calculators.torchsim.OrbCalculator href "" "jfchemistry.calculators.torchsim.OrbCalculator"
              click jfchemistry.calculators.torchsim.torchsim_calculator.TorchSimCalculator href "" "jfchemistry.calculators.torchsim.torchsim_calculator.TorchSimCalculator"
              click jfchemistry.calculators.base.MachineLearnedInteratomicPotentialCalculator href "" "jfchemistry.calculators.base.MachineLearnedInteratomicPotentialCalculator"
              click jfchemistry.calculators.base.Calculator href "" "jfchemistry.calculators.base.Calculator"
            

Orb Calculator.

ATTRIBUTE DESCRIPTION
name

Name of the calculator (default: "Orb Calculator").

TYPE: str

model

The Orb model to use (default: "orb_v3_conservative_omol").

TYPE: Literal['orb_v3_conservative_omol', 'orb_v3_direct_omol', 'orb_v3_direct_20_omat', 'orb_v3_direct_20_mpa', 'orb_v3_direct_inf_omat', 'orb_v3_direct_inf_mpa', 'orb_v3_conservative_20_omat', 'orb_v3_conservative_20_mpa', 'orb_v3_conservative_inf_omat', 'orb_v3_conservative_inf_mpa']

device

The device to use for the model (default: "cpu").

TYPE: Literal['cpu', 'cuda']

conservative

Whether to use the conservative model (default: True).

TYPE: bool

precision

The precision to use for the model (default: "float32-high").

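TYPE: Literal['float32-high', 'float32-highest', 'float64']

compile

Whether to compile the model (default: True).

TYPE: bool

compute_stress

Whether to compute the stress (default: False).

TYPE: bool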

Source code in jfchemistry/calculators/torchsim/orb_calculator.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
@dataclass
class OrbCalculator(TorchSimCalculator, MachineLearnedInteratomicPotentialCalculator, MSONable):
    """Orb Calculator.

    Evaluates the potential energy, the forces and (optionally) the stress of a
    system by running a pretrained Orb model through ``ts.static``.

    Attributes:
        name: Name of the calculator (default: "Orb Calculator").
        model: The Orb model to use (default: "orb_v3_conservative_omol").
        device: The device to use for the model (default: "cpu").
        conservative: Whether to use the conservative model (default: True).
        precision: The precision to use for the model (default: "float32-high").
        compile: Whether to compile the model (default: True).
        compute_stress: Whether to compute the stress (default: False).
    """

    name: str = "Orb Calculator"
    model: Literal[
        "orb_v3_conservative_omol",
        "orb_v3_direct_omol",
        "orb_v3_direct_20_omat",
        "orb_v3_direct_20_mpa",
        "orb_v3_direct_inf_omat",
        "orb_v3_direct_inf_mpa",
        "orb_v3_conservative_20_omat",
        "orb_v3_conservative_20_mpa",
        "orb_v3_conservative_inf_omat",
        "orb_v3_conservative_inf_mpa",
    ] = field(default="orb_v3_conservative_omol", metadata={"description": "The ORB model to use"})
    device: Literal["cpu", "cuda"] = field(
        default="cpu", metadata={"description": "The device to use for the model"}
    )
    # NOTE(review): ``conservative`` is not read by any method of this class;
    # the variant is selected through the ``model`` name instead -- confirm
    # whether the field is still needed.
    conservative: bool = field(
        default=True, metadata={"description": "Whether to use the conservative model"}
    )
    precision: Literal["float32-high", "float32-highest", "float64"] = field(
        default="float32-high", metadata={"description": "The precision to use for the model"}
    )
    compile: bool = field(default=True, metadata={"description": "Whether to compile the model"})
    compute_stress: bool = field(
        default=False, metadata={"description": "Whether to compute the stress"}
    )

    def _get_model(self) -> OrbModel:
        """Load the pretrained Orb model, wrap it and cache it on the instance.

        Returns:
            The newly constructed ``OrbModel``; it is also stored as
            ``self._model`` so that ``_get_properties`` can reuse it.
        """
        # The loader is looked up on ``pretrained`` by the configured model name.
        orb_model = getattr(pretrained, self.model)(
            device=self.device, precision=self.precision, compile=self.compile
        )
        model = OrbModel(
            model=orb_model,
            device=self.device,
            compute_stress=self.compute_stress,
        )
        self._model = model
        return model

    def _get_properties(self, system: SiteCollection) -> Properties:
        """Compute the Orb properties of ``system``.

        Args:
            system: The structure to evaluate. It is converted with
                ``to_ase_atoms()`` before being passed to ``ts.static``.

        Returns:
            An ``OrbProperties`` object holding the forces (eV/angstrom), the
            total energy (eV) and, when ``compute_stress`` is enabled, the stress
            (eV/angstrom**3). The stress entry is ``None`` otherwise.
        """
        # Build the model lazily; later calls reuse the cached instance.
        if not hasattr(self, "_model"):
            self._get_model()

        # NOTE(review): the integer keys follow the trajectory-reporter
        # convention of ``ts`` (presumably a reporting interval) -- confirm
        # against the TorchSim documentation before changing them.
        prop_calculators = {
            10: {"potential_energy": lambda state: state.energy},
            20: {"forces": lambda state: state.forces},
        }
        if self.compute_stress:
            prop_calculators[30] = {"stress": lambda state: state.stress}

        final_results = ts.static(
            system=system.to_ase_atoms(),
            model=self._model,
            # we don't want to save any trajectories this time, just get the properties
            trajectory_reporter={"filenames": None, "prop_calculators": prop_calculators},
        )

        # Only a single system is submitted, so only the first entry is read.
        results = final_results[0]
        forces = results["forces"]
        energy = results["potential_energy"]

        # Always bind the stress property so it is never referenced while
        # undefined; it stays ``None`` unless the stress was requested.
        stress_property = None
        if self.compute_stress:
            stress_property = SystemProperty(
                name="Stress",
                value=results["stress"].tolist() * ureg.eV / ureg.angstrom**3,
                description=f"Stress predicted by the {self.model} model",
            )

        return OrbProperties(
            atomic=OrbAtomicProperties(
                forces=AtomicProperty(
                    name="Orb Forces",
                    value=forces.tolist() * ureg.eV / ureg.angstrom,
                    description=f"Forces predicted by the {self.model} model",
                ),
            ),
            system=OrbSystemProperties(
                total_energy=SystemProperty(
                    name="Total Energy",
                    value=energy.tolist() * ureg.eV,
                    description=f"Total energy predicted by the {self.model} model",
                ),
                stress=stress_property,
            ),
        )